KingNish commited on
Commit
2f818fd
1 Parent(s): e392e21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -68
app.py CHANGED
@@ -13,17 +13,14 @@ import tempfile
13
  import uuid
14
  import time
15
  from concurrent.futures import ThreadPoolExecutor
16
- import asyncio
17
 
18
  torch.set_float32_matmul_precision("medium")
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # Load both BiRefNet models
22
- birefnet = AutoModelForImageSegmentation.from_pretrained(
23
- "ZhengPeng7/BiRefNet", trust_remote_code=True)
24
  birefnet.to(device)
25
- birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
26
- "ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
27
  birefnet_lite.to(device)
28
 
29
  transform_image = transforms.Compose([
@@ -32,74 +29,77 @@ transform_image = transforms.Compose([
32
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
33
  ])
34
 
35
- # Function to process a single frame asynchronously
36
- async def process_frame_async(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
37
- pil_image = Image.fromarray(frame)
38
- if bg_type == "Color":
39
- processed_image = process(pil_image, color, fast_mode)
40
- elif bg_type == "Image":
41
- processed_image = process(pil_image, bg, fast_mode)
42
- elif bg_type == "Video":
43
- background_frame = background_frames[bg_frame_index % len(background_frames)]
44
- bg_frame_index += 1
45
- background_image = Image.fromarray(background_frame)
46
- processed_image = process(pil_image, background_image, fast_mode)
47
- else:
48
- processed_image = pil_image # Default to original image if no background is selected
49
- return np.array(processed_image), bg_frame_index
 
 
 
 
50
 
51
  @spaces.GPU
52
- async def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down", fast_mode=True, max_workers=6):
53
- start_time = time.time() # Start the timer
54
-
55
- video = mp.VideoFileClip(vid)
56
- if fps == 0:
57
- fps = video.fps
58
-
59
- audio = video.audio
60
- frames = list(video.iter_frames(fps=fps))
61
-
62
- processed_frames = []
63
- yield gr.update(visible=True), gr.update(visible=False), f"Processing started... Elapsed time: 0 seconds"
64
-
65
- if bg_type == "Video":
66
- background_video = mp.VideoFileClip(bg_video)
67
- if background_video.duration < video.duration:
68
- if video_handling == "slow_down":
69
- background_video = background_video.fx(mp.vfx.speedx, factor=video.duration / background_video.duration)
70
- else: # video_handling == "loop"
71
- background_video = mp.concatenate_videoclips([background_video] * int(video.duration / background_video.duration + 1))
72
- background_frames = list(background_video.iter_frames(fps=fps))
73
- else:
74
- background_frames = None
75
-
76
- bg_frame_index = 0
77
 
78
- # Use ThreadPoolExecutor for parallel processing with specified max_workers
79
- loop = asyncio.get_event_loop()
80
- tasks = [
81
- loop.run_in_executor(
82
- None, process_frame_async, frames[i], bg_type, bg_image, fast_mode, bg_frame_index, background_frames, color
83
- )
84
- for i in range(len(frames))
85
- ]
86
-
87
- for future in asyncio.as_completed(tasks):
88
- result, bg_frame_index = await future
89
- processed_frames.append(result)
 
 
 
90
  elapsed_time = time.time() - start_time
91
- yield result, None, f"Processing frame {len(processed_frames)}... Elapsed time: {elapsed_time:.2f} seconds"
 
92
 
93
- processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
94
- processed_video = processed_video.set_audio(audio)
95
-
96
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
97
- temp_filepath = temp_file.name
98
- processed_video.write_videofile(temp_filepath, codec="libx264")
99
-
100
- elapsed_time = time.time() - start_time
101
- yield gr.update(visible=False), gr.update(visible=True), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
102
- yield processed_frames[-1], temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
103
 
104
  def process(image, bg, fast_mode=False):
105
  image_size = image.size
 
13
  import uuid
14
  import time
15
  from concurrent.futures import ThreadPoolExecutor
 
16
 
17
  torch.set_float32_matmul_precision("medium")
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  # Load both BiRefNet models
21
+ birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
 
22
  birefnet.to(device)
23
+ birefnet_lite = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
 
24
  birefnet_lite.to(device)
25
 
26
  transform_image = transforms.Compose([
 
29
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
30
  ])
31
 
32
+ # Function to process a single frame
33
+ def process_frame(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
34
+ try:
35
+ pil_image = Image.fromarray(frame)
36
+ if bg_type == "Color":
37
+ processed_image = process(pil_image, color, fast_mode)
38
+ elif bg_type == "Image":
39
+ processed_image = process(pil_image, bg, fast_mode)
40
+ elif bg_type == "Video":
41
+ background_frame = background_frames[bg_frame_index % len(background_frames)]
42
+ bg_frame_index += 1
43
+ background_image = Image.fromarray(background_frame)
44
+ processed_image = process(pil_image, background_image, fast_mode)
45
+ else:
46
+ processed_image = pil_image # Default to original image if no background is selected
47
+ return np.array(processed_image), bg_frame_index
48
+ except Exception as e:
49
+ print(f"Error processing frame: {e}")
50
+ return frame, bg_frame_index
51
 
52
  @spaces.GPU
53
+ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down", fast_mode=True, max_workers=6):
54
+ try:
55
+ start_time = time.time() # Start the timer
56
+ video = mp.VideoFileClip(vid)
57
+ if fps == 0:
58
+ fps = video.fps
59
+
60
+ audio = video.audio
61
+ frames = list(video.iter_frames(fps=fps))
62
+
63
+ processed_frames = []
64
+ yield gr.update(visible=True), gr.update(visible=False), f"Processing started... Elapsed time: 0 seconds"
65
+
66
+ if bg_type == "Video":
67
+ background_video = mp.VideoFileClip(bg_video)
68
+ if background_video.duration < video.duration:
69
+ if video_handling == "slow_down":
70
+ background_video = background_video.fx(mp.vfx.speedx, factor=video.duration / background_video.duration)
71
+ else: # video_handling == "loop"
72
+ background_video = mp.concatenate_videoclips([background_video] * int(video.duration / background_video.duration + 1))
73
+ background_frames = list(background_video.iter_frames(fps=fps))
74
+ else:
75
+ background_frames = None
76
+
77
+ bg_frame_index = 0 # Initialize background frame index
78
 
79
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
80
+ futures = [executor.submit(process_frame, frames[i], bg_type, bg_image, fast_mode, bg_frame_index, background_frames, color) for i in range(len(frames))]
81
+ for future in futures:
82
+ result, bg_frame_index = future.result()
83
+ processed_frames.append(result)
84
+ elapsed_time = time.time() - start_time
85
+ yield result, None, f"Processing frame {len(processed_frames)}... Elapsed time: {elapsed_time:.2f} seconds"
86
+
87
+ processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
88
+ processed_video = processed_video.set_audio(audio)
89
+
90
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
91
+ temp_filepath = temp_file.name
92
+ processed_video.write_videofile(temp_filepath, codec="libx264")
93
+
94
  elapsed_time = time.time() - start_time
95
+ yield gr.update(visible=False), gr.update(visible=True), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
96
+ yield processed_frames[-1], temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
97
 
98
+ except Exception as e:
99
+ print(f"Error: {e}")
100
+ elapsed_time = time.time() - start_time
101
+ yield gr.update(visible=False), gr.update(visible=True), f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
102
+ yield None, f"Error processing video: {e}", f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
 
 
 
 
 
103
 
104
  def process(image, bg, fast_mode=False):
105
  image_size = image.size