KingNish commited on
Commit
262a1a2
1 Parent(s): 0b6644c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -110
app.py CHANGED
@@ -11,133 +11,96 @@ import numpy as np
11
  import os
12
  import tempfile
13
  import uuid
14
- from concurrent.futures import ThreadPoolExecutor
15
- import torch.nn as nn
16
- import torch.cuda.amp # for mixed precision training
17
 
18
- # Enable tensor cores for faster computation
19
- torch.set_float32_matmul_precision("high")
20
- torch.backends.cudnn.benchmark = True # Enable cudnn autotuner
21
 
22
- # Initialize model with optimization flags
23
  birefnet = AutoModelForImageSegmentation.from_pretrained(
24
  "ZhengPeng7/BiRefNet", trust_remote_code=True
 
 
 
 
 
 
 
25
  )
26
- birefnet.to("cuda").eval() # Ensure model is in eval mode
27
- birefnet = torch.jit.script(birefnet) # JIT compilation for faster inference
28
-
29
- # Pre-compile transforms for better performance
30
- transform_image = transforms.Compose([
31
- transforms.Resize((1024, 1024), antialias=True),
32
- transforms.ToTensor(),
33
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
34
- ])
35
-
36
- # Increased batch size for better GPU utilization
37
- BATCH_SIZE = 8 # Increased from 3
38
- NUM_WORKERS = 4 # For parallel processing
39
-
40
- # Create a thread pool for parallel processing
41
- executor = ThreadPoolExecutor(max_workers=NUM_WORKERS)
42
-
43
- def process_batch(batch_data):
44
- """Process a batch of frames in parallel"""
45
- images, backgrounds, image_sizes = zip(*batch_data)
46
-
47
- # Stack images for batch processing
48
- input_tensor = torch.stack(images).to("cuda")
49
-
50
- # Use automatic mixed precision for faster computation
51
- with torch.cuda.amp.autocast():
52
- with torch.no_grad():
53
- preds = birefnet(input_tensor)[-1].sigmoid().cpu()
54
-
55
- processed_frames = []
56
- for pred, bg, size in zip(preds, backgrounds, image_sizes):
57
- mask = transforms.ToPILImage()(pred.squeeze()).resize(size)
58
-
59
- if isinstance(bg, str) and bg.startswith("#"):
60
- color_rgb = tuple(int(bg[i:i+2], 16) for i in (1, 3, 5))
61
- background = Image.new("RGBA", size, color_rgb + (255,))
62
- elif isinstance(bg, Image.Image):
63
- background = bg.convert("RGBA").resize(size)
64
- else:
65
- background = Image.open(bg).convert("RGBA").resize(size)
66
-
67
- # Use PIL's faster composite operation
68
- image = Image.composite(images[0].resize(size), background, mask)
69
- processed_frames.append(np.array(image))
70
-
71
- return processed_frames
72
 
73
  @spaces.GPU
74
  def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
75
  try:
76
- # Load video more efficiently
77
- video = mp.VideoFileClip(vid, audio_buffersize=2000)
78
  if fps == 0:
79
  fps = video.fps
80
  audio = video.audio
81
-
82
- # Pre-calculate video parameters
83
- total_frames = int(video.fps * video.duration)
84
- frames = list(video.iter_frames(fps=fps)) # Load all frames at once
85
-
86
- # Pre-process background if using video
87
  if bg_type == "Video":
88
- bg_video_clip = mp.VideoFileClip(bg_video)
89
- if bg_video_clip.duration < video.duration:
90
  if video_handling == "slow_down":
91
- bg_video_clip = bg_video_clip.fx(mp.vfx.speedx,
92
- factor=video.duration / bg_video_clip.duration)
93
  else:
94
- multiplier = int(video.duration / bg_video_clip.duration + 1)
95
- bg_video_clip = mp.concatenate_videoclips([bg_video_clip] * multiplier)
96
- background_frames = list(bg_video_clip.iter_frames(fps=fps))
97
-
98
- # Process frames in batches
99
- processed_frames = []
100
- for i in range(0, len(frames), BATCH_SIZE):
101
- batch_frames = frames[i:i + BATCH_SIZE]
102
- batch_data = []
103
-
104
- for j, frame in enumerate(batch_frames):
105
- pil_image = Image.fromarray(frame)
106
- image_size = pil_image.size
107
- transformed_image = transform_image(pil_image)
108
-
109
- if bg_type == "Color":
110
- bg = color
111
- elif bg_type == "Image":
112
- bg = bg_image
113
- else: # Video
114
- frame_idx = (i + j) % len(background_frames)
115
- bg = Image.fromarray(background_frames[frame_idx])
116
-
117
- batch_data.append((transformed_image, bg, image_size))
118
-
119
- # Process batch
120
- batch_results = process_batch(batch_data)
121
- processed_frames.extend(batch_results)
122
-
123
- # Yield progress updates
124
- if len(batch_results) > 0:
125
- yield batch_results[-1], None
126
-
127
- # Create output video
 
 
 
 
 
 
 
 
 
 
 
 
128
  processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
129
- if audio is not None:
130
- processed_video = processed_video.set_audio(audio)
131
-
132
- # Use temporary file
133
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
134
- output_path = tmp_file.name
135
- processed_video.write_videofile(output_path, codec="libx264",
136
- preset='ultrafast', threads=NUM_WORKERS)
137
-
138
  yield gr.update(visible=False), gr.update(visible=True)
139
- yield processed_frames[-1], output_path
140
-
141
  except Exception as e:
142
  print(f"Error: {e}")
143
  yield gr.update(visible=False), gr.update(visible=True)
 
11
  import os
12
  import tempfile
13
  import uuid
 
 
 
14
 
15
+ torch.set_float32_matmul_precision("highest")
 
 
16
 
 
17
  birefnet = AutoModelForImageSegmentation.from_pretrained(
18
  "ZhengPeng7/BiRefNet", trust_remote_code=True
19
+ ).to("cuda")
20
+ transform_image = transforms.Compose(
21
+ [
22
+ transforms.Resize((1024, 1024)),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
25
+ ]
26
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  @spaces.GPU
29
  def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
30
  try:
31
+ video = mp.VideoFileClip(vid)
 
32
  if fps == 0:
33
  fps = video.fps
34
  audio = video.audio
35
+ frames = video.iter_frames(fps=fps)
36
+ processed_frames = []
37
+ yield gr.update(visible=True), gr.update(visible=False)
38
+
 
 
39
  if bg_type == "Video":
40
+ background_video = mp.VideoFileClip(bg_video)
41
+ if background_video.duration < video.duration:
42
  if video_handling == "slow_down":
43
+ background_video = background_video.fx(mp.vfx.speedx, factor=video.duration / background_video.duration)
 
44
  else:
45
+ background_video = mp.concatenate_videoclips([background_video] * int(video.duration / background_video.duration + 1))
46
+ background_frames = list(background_video.iter_frames(fps=fps))
47
+ elif bg_type in ["Color", "Image"]:
48
+ # Prepare background once if it's a static image or color
49
+ if bg_type == "Color":
50
+ color_rgb = tuple(int(color[i:i+2], 16) for i in (1, 3, 5))
51
+ background_pil = Image.new("RGBA", (1024, 1024), color_rgb + (255,))
52
+ else: # bg_type == "Image":
53
+ background_pil = Image.open(bg_image).convert("RGBA").resize((1024, 1024))
54
+ background_tensor = transforms.ToTensor(background_pil).to("cuda")
55
+ else:
56
+ background_tensor = None
57
+
58
+
59
+ bg_frame_index = 0
60
+ frame_batch = []
61
+ for i, frame in enumerate(frames):
62
+ frame = Image.fromarray(frame)
63
+ frame = transforms.ToTensor(frame).to('cuda')
64
+ frame_batch.append(frame)
65
+
66
+ if len(frame_batch) >= 3 or i == int(video.fps * video.duration) - 1 :
67
+ input_images = torch.stack(frame_batch).to("cuda")
68
+ with torch.no_grad():
69
+ preds = birefnet(input_images)[-1].sigmoid()
70
+ for j, pred in enumerate(preds):
71
+ if bg_type == "Video":
72
+ if video_handling == "slow_down":
73
+ background_frame = background_frames[bg_frame_index % len(background_frames)]
74
+ bg_frame_index += 1
75
+ background_image = Image.fromarray(background_frame).resize((1024, 1024))
76
+ background_tensor = transforms.ToTensor(background_image).to("cuda")
77
+ else: # video_handling == "loop"
78
+ background_frame = background_frames[bg_frame_index % len(background_frames)]
79
+ bg_frame_index += 1
80
+ background_image = Image.fromarray(background_frame).resize((1024, 1024))
81
+ background_tensor = transforms.ToTensor(background_image).to("cuda")
82
+ mask = transforms.ToPILImage()(pred.cpu().squeeze())
83
+ processed_image = Image.composite(transforms.ToPILImage()(frame_batch[j].cpu()), transforms.ToPILImage()(background_tensor.cpu()), mask).resize(video.size)
84
+
85
+ processed_frames.append(np.array(processed_image))
86
+ yield processed_image, None
87
+
88
+ frame_batch = []
89
+
90
+
91
  processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
92
+ processed_video = processed_video.set_audio(audio)
93
+
94
+ temp_dir = "temp"
95
+ os.makedirs(temp_dir, exist_ok=True)
96
+ unique_filename = str(uuid.uuid4()) + ".mp4"
97
+ temp_filepath = os.path.join(temp_dir, unique_filename)
98
+
99
+ processed_video.write_videofile(temp_filepath, codec="libx264", logger=None)
100
+
101
  yield gr.update(visible=False), gr.update(visible=True)
102
+ yield processed_image, temp_filepath
103
+
104
  except Exception as e:
105
  print(f"Error: {e}")
106
  yield gr.update(visible=False), gr.update(visible=True)