# bgremoval / app.py
import time
import torch
import gc
from transformers import AutoConfig, AutoModelForImageSegmentation
from PIL import Image
from torchvision import transforms
import gradio as gr

def load_model():
    # Fetch the config first (with trust_remote_code=True).
    config = AutoConfig.from_pretrained("zhengpeng7/BiRefNet_lite", trust_remote_code=True)
    # Ensure the model is not treated as a seq2seq (encoder-decoder) model.
    config.is_encoder_decoder = False
    # Optionally, block calls to get_text_config if needed:
    # config.get_text_config = lambda decoder=True: None
    # Now load the model with the tweaked config.
    model = AutoModelForImageSegmentation.from_pretrained(
        "zhengpeng7/BiRefNet_lite",
        config=config,
        trust_remote_code=True
    )
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    return model, device

birefnet, device = load_model()
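# Loading at import time means the weights are initialized once when the app
# starts, so every Gradio request reuses the same model instance.
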
# Preprocessing
image_size = (1024, 1024)
transform_image = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
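# Note: the Normalize() mean/std above are the standard ImageNet statistics
# used by most pretrained vision backbones, and 1024x1024 is the working
# resolution this app feeds the model.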

def run_inference(images, model, device):
    inputs = []
    original_sizes = []
    for img in images:
        original_sizes.append(img.size)
        inputs.append(transform_image(img))
    input_tensor = torch.stack(inputs).to(device)
    try:
        with torch.no_grad():
            # The model returns a list of predictions; take the last (final)
            # one. Adjust the indexing if your model's outputs are structured
            # differently.
            preds = model(input_tensor)[-1].sigmoid().cpu()
    except torch.cuda.OutOfMemoryError:  # torch.cuda variant also works on PyTorch < 2.4
        del input_tensor
        torch.cuda.empty_cache()
        raise
    # Post-process: resize each mask back to its original size and use it as
    # the alpha channel of an RGBA cutout.
    results = []
    for i, img in enumerate(images):
        pred = preds[i].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(original_sizes[i])
        result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
        result.paste(img, mask=mask)
        results.append(result)
    # Cleanup
    del input_tensor, preds
    gc.collect()
    torch.cuda.empty_cache()
    return results
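
# Example usage (hypothetical file names), reusing the model loaded above:
#   imgs = [Image.open("photo.jpg").convert("RGB")]
#   cutouts = run_inference(imgs, birefnet, device)
#   cutouts[0].save("photo_cutout.png")  # RGBA with transparent background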

def binary_search_max(images):
    # After an OOM, binary-search for the largest batch size that fits.
    global birefnet, device
    low, high = 1, len(images)
    best = None
    best_count = 0
    while low <= high:
        mid = (low + high) // 2
        batch = images[:mid]
        try:
            birefnet, device = load_model()  # re-init to reduce memory fragmentation
            res = run_inference(batch, birefnet, device)
            best = res
            best_count = mid
            low = mid + 1
        except torch.cuda.OutOfMemoryError:
            high = mid - 1
    return best, best_count
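
# Worked example (hypothetical numbers): with 8 images and memory for at most
# 5, the search tries mid=4 (ok, low=5), mid=6 (OOM, high=5), then mid=5 (ok,
# low=6) and stops, returning best_count=5. Each probe reloads the model, so
# the fallback costs roughly log2(n) reload-plus-inference cycles.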

def extract_objects(filepaths):
    images = [Image.open(p).convert("RGB") for p in filepaths]
    start_time = time.time()
    # First attempt: process all images in one batch.
    try:
        results = run_inference(images, birefnet, device)
        total_time = time.time() - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary
    except torch.cuda.OutOfMemoryError:
        # OOM occurred; fall back to searching for a feasible batch size.
        initial_attempt_time = time.time() - start_time
        best, best_count = binary_search_max(images)
        total_time = time.time() - start_time
        if best is None:
            # Not even a single image fits in memory.
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                f"Could not process even a single image.\n"
                f"Total time including fallback attempts: {total_time:.2f}s."
            )
            return [], summary
        else:
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                f"Found that {best_count} images can be processed without OOM.\n"
                f"Total time including fallback attempts: {total_time:.2f}s.\n"
                f"Next time, try using up to {best_count} images."
            )
            return best, summary

iface = gr.Interface(
    fn=extract_objects,
    inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
    outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
    description="Upload as many images as you want. If an out-of-memory error occurs, the fallback logic finds the maximum feasible batch size."
)

if __name__ == "__main__":
    iface.launch()
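
# To run locally (assumes torch, torchvision, transformers, gradio, and
# Pillow are installed):
#   python app.py
# Gradio will print a local URL to open in a browser.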