MonsterMMORPG commited on
Commit
604e3cb
1 Parent(s): 47565bc

Upload RemoveBG_By_SECourses.py

Browse files
Files changed (1) hide show
  1. RemoveBG_By_SECourses.py +200 -0
RemoveBG_By_SECourses.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+ import argparse
7
+ from pathlib import Path
8
+ from glob import glob
9
+ from typing import Optional, Tuple, List
10
+ from PIL import Image
11
+ from transformers import AutoModelForImageSegmentation
12
+ from torchvision import transforms
13
+ import time
14
+ import os
15
+ import platform
16
+
17
+ def parse_args():
18
+ parser = argparse.ArgumentParser(description="Run the image segmentation app")
19
+ parser.add_argument("--share", action="store_true", help="Enable sharing of the Gradio interface")
20
+ return parser.parse_args()
21
+
22
+ torch.set_float32_matmul_precision('high')
23
+ torch.jit.script = lambda f: f
24
+
25
+ os.environ['HOME'] = os.path.expanduser('~')
26
+
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+
29
+ def open_folder():
30
+ open_folder_path = os.path.abspath("results")
31
+ if platform.system() == "Windows":
32
+ os.startfile(open_folder_path)
33
+ elif platform.system() == "Linux":
34
+ os.system(f'xdg-open "{open_folder_path}"')
35
+
36
+ class ImagePreprocessor():
37
+ def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
38
+ self.transform_image = transforms.Compose([
39
+ transforms.ToTensor(),
40
+ ])
41
+ self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
42
+
43
+ def proc(self, image: Image.Image) -> torch.Tensor:
44
+ image = image.convert('RGB') # Convert to RGB
45
+ image = self.transform_image(image)
46
+ return self.normalize(image)
47
+
48
+ usage_to_weights_file = {
49
+ 'General': 'BiRefNet',
50
+ 'General-Lite': 'BiRefNet_T',
51
+ 'Portrait': 'BiRefNet-portrait',
52
+ 'DIS': 'BiRefNet-DIS5K',
53
+ 'HRSOD': 'BiRefNet-HRSOD',
54
+ 'COD': 'BiRefNet-COD',
55
+ 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
56
+ }
57
+
58
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
59
+ birefnet.to(device)
60
+ birefnet.eval()
61
+
62
+ def process_single_image(image_path: str, resolution: str, output_folder: str) -> Tuple[str, str, float]:
63
+ start_time = time.time()
64
+
65
+ image = Image.open(image_path).convert('RGBA')
66
+
67
+ if resolution == '':
68
+ resolution = f"{image.width}x{image.height}"
69
+ resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
70
+
71
+ image_shape = image.size[::-1]
72
+ image_pil = image.resize(tuple(resolution))
73
+
74
+ image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
75
+ image_proc = image_preprocessor.proc(image_pil)
76
+ image_proc = image_proc.unsqueeze(0)
77
+
78
+ with torch.no_grad():
79
+ scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
80
+
81
+ if device == 'cuda':
82
+ scaled_pred_tensor = scaled_pred_tensor.cpu()
83
+
84
+ pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
85
+
86
+ pred_rgba = np.zeros((*pred.shape, 4), dtype=np.uint8)
87
+ pred_rgba[..., :3] = (pred[..., np.newaxis] * 255).astype(np.uint8)
88
+ pred_rgba[..., 3] = (pred * 255).astype(np.uint8)
89
+
90
+ image_array = np.array(image)
91
+ image_pred = image_array * (pred_rgba / 255.0)
92
+
93
+ output_image = Image.fromarray(image_pred.astype(np.uint8), 'RGBA')
94
+
95
+ base_filename = os.path.splitext(os.path.basename(image_path))[0]
96
+ output_path = os.path.join(output_folder, f"{base_filename}.png")
97
+
98
+ counter = 1
99
+ while os.path.exists(output_path):
100
+ output_path = os.path.join(output_folder, f"{base_filename}_{counter:04d}.png")
101
+ counter += 1
102
+
103
+ output_image.save(output_path)
104
+
105
+ processing_time = time.time() - start_time
106
+ print(f"Processed {image_path} in {processing_time:.4f} seconds") # Added this line to print processing time
107
+ return image_path, output_path, processing_time
108
+
109
+ def predict(
110
+ image: str,
111
+ resolution: str,
112
+ weights_file: Optional[str],
113
+ batch_folder: Optional[str] = None,
114
+ output_folder: Optional[str] = None,
115
+ is_batch: bool = False
116
+ ) -> Tuple[str, List[Tuple[str, str]]]:
117
+ global birefnet
118
+ _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
119
+ print('Using weights:', _weights_file)
120
+ birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
121
+ birefnet.to(device)
122
+ birefnet.eval()
123
+
124
+ if not output_folder:
125
+ output_folder = 'results'
126
+ os.makedirs(output_folder, exist_ok=True)
127
+
128
+ results = []
129
+
130
+ if is_batch and batch_folder:
131
+ image_files = glob(os.path.join(batch_folder, '*'))
132
+ total_images = len(image_files)
133
+ processed_images = 0
134
+ start_time = time.time()
135
+
136
+ for img_path in image_files:
137
+ try:
138
+ input_path, output_path, proc_time = process_single_image(img_path, resolution, output_folder)
139
+ results.append((output_path, f"{proc_time:.4f} seconds"))
140
+ processed_images += 1
141
+ elapsed_time = time.time() - start_time
142
+ avg_time_per_image = elapsed_time / processed_images
143
+ estimated_time_left = avg_time_per_image * (total_images - processed_images)
144
+
145
+ status = f"Processed {processed_images}/{total_images} images. Estimated time left: {estimated_time_left:.2f} seconds"
146
+ print(status)
147
+ except Exception as e:
148
+ print(f"Error processing {img_path}: {str(e)}")
149
+ continue
150
+
151
+ return f"Batch processing complete. Processed {processed_images}/{total_images} images.", results
152
+ else:
153
+ input_path, output_path, proc_time = process_single_image(image, resolution, output_folder)
154
+ results.append((output_path, f"{proc_time:.4f} seconds"))
155
+ return "Single image processing complete.", results
156
+
157
+ def create_interface():
158
+ with gr.Blocks() as demo:
159
+ gr.Markdown("## SECourses Improved BiRefNet V1 'Bilateral Reference for High-Resolution Dichotomous Image Segmentation' APP - SOTA Background Remover")
160
+ gr.Markdown("## Most Advanced Latest Version On : https://www.patreon.com/posts/109913645")
161
+
162
+ with gr.Row():
163
+ input_image = gr.Image(type="filepath", label="Input Image",height=512)
164
+ output_image = gr.Gallery(label="Output Image", elem_id="gallery",height=512)
165
+
166
+
167
+ with gr.Row():
168
+ resolution = gr.Textbox(label="Resolution", placeholder="1024x1024 - Optional - Don't enter to use original image resolution - Higher res uses more VRAM but still works perfect with shared VRAM so fast")
169
+ weights_file = gr.Dropdown(choices=list(usage_to_weights_file.keys()), value="General", label="Weights File")
170
+ btn_open_outputs = gr.Button("Open Results Folder")
171
+ btn_open_outputs.click(fn=open_folder)
172
+
173
+ with gr.Row():
174
+ batch_folder = gr.Textbox(label="Batch Folder Path")
175
+ output_folder = gr.Textbox(label="Output Folder Path", value="results")
176
+
177
+ with gr.Row():
178
+ submit_button = gr.Button("Process")
179
+ batch_button = gr.Button("Process Batch")
180
+
181
+ output_text = gr.Textbox(label="Processing Status")
182
+
183
+ submit_button.click(
184
+ predict,
185
+ inputs=[input_image, resolution, weights_file, batch_folder, output_folder, gr.Checkbox(value=False, visible=False)],
186
+ outputs=[output_text, output_image]
187
+ )
188
+
189
+ batch_button.click(
190
+ predict,
191
+ inputs=[input_image, resolution, weights_file, batch_folder, output_folder, gr.Checkbox(value=True, visible=False)],
192
+ outputs=[output_text, output_image]
193
+ )
194
+
195
+ return demo
196
+
197
+ if __name__ == "__main__":
198
+ args = parse_args()
199
+ demo = create_interface()
200
+ demo.launch(inbrowser=True, share=args.share)