# Hugging Face Spaces — status: Running
import os
import tempfile
import traceback

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Setup constants: prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Preprocessing applied to every input before inference:
# - resize to a fixed 1024x1024;
# - convert to a float tensor;
# - normalize each channel with the standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# Load the model once, at import time, so every request reuses it. If loading
# fails, the app still starts with ``model = None``; ``predict`` checks for that
# and reports the problem to the user instead of running inference.
try:
    # Allow faster, reduced-precision float32 matmuls where the hardware
    # supports them (see torch.set_float32_matmul_precision).
    torch.set_float32_matmul_precision("high")
    model = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet_lite",
        trust_remote_code=True,
    ).to(DEVICE)
    # Make inference mode explicit. This is a no-op if the model is already
    # in eval mode.
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    # Keep the full traceback in the server logs. The one-line message alone
    # is rarely enough to diagnose a failed download or remote-code error.
    traceback.print_exc()
    model = None
def process_image(image):
    """Remove the background of a single PIL image.

    Steps:
    1. Convert the image to RGB.
    2. Preprocess it with ``transform_image`` and run it through the global
       ``model`` on ``DEVICE``.
    3. Squash the model's last output to [0, 1] with a sigmoid.
    4. Resize that map back to the image's original size and attach it as
       the alpha channel.

    Args:
        image: A PIL image in any mode that can be converted to RGB.

    Returns:
        A copy of the RGB image with the predicted mask attached as alpha.
    """
    rgb = image.convert("RGB")
    width_height = rgb.size
    batch = transform_image(rgb).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # ``[-1]`` takes the last element of whatever sequence the model
        # returns (presumably its final prediction — confirm against BiRefNet).
        raw = model(batch)[-1]
        probabilities = raw.sigmoid().cpu()
    # Assumes a (1, 1, H, W) output, so squeezing leaves an (H, W) map.
    # NOTE(review): confirm this shape against the model.
    alpha = transforms.ToPILImage()(probabilities[0].squeeze())
    alpha = alpha.resize(width_height)
    cutout = rgb.copy()
    cutout.putalpha(alpha)
    return cutout
def predict(image):
    """Gradio callback: remove the background and offer the result as a PNG.

    Args:
        image: The uploaded PIL image, or ``None`` when the input is empty.

    Returns:
        A ``(result_image, file_path)`` pair for the two output components:
        the background-removed image and the path of a PNG copy of it.
        Both are ``None`` when no image was supplied.

    Raises:
        gr.Error: If the model failed to load at startup, or if processing
            or saving the image fails. The original exception is chained.
    """
    if model is None:
        raise gr.Error("Model not loaded. Check server logs.")
    if image is None:
        return None, None  # Return None for both image and file
    try:
        result_image = process_image(image)
        # Each result goes into its own fresh temporary directory rather than
        # a single shared file in the working directory. This way overlapping
        # requests cannot overwrite (or serve) each other's output, and the
        # download keeps the name "processed_image.png".
        # NOTE(review): these temp dirs are not cleaned up here.
        file_path = os.path.join(tempfile.mkdtemp(), "processed_image.png")
        result_image.save(file_path, "PNG")
    except Exception as e:
        raise gr.Error(f"Error processing image: {e}") from e
    return result_image, file_path
# Gradio interface: one PIL image in. Two outputs: the background-removed
# image and a downloadable PNG file.
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Image(type="pil", label="Processed Image"),
    gr.File(label="Download Processed Image"),
]
interface = gr.Interface(
    fn=predict,
    inputs=_input_component,
    outputs=_output_components,
    examples=[["example.jpeg"]],
    title="Background Removal App",
    description=(
        "Upload an image to remove its background and download the processed "
        "image as a PNG."
    ),
)
# Start the Gradio server.
interface.launch()