# Hugging Face Spaces page status at capture time: Runtime error
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# UI strings (Traditional Chinese), restored from a mis-encoded copy:
#   title       -> "Remove Background Demo"
#   description -> "Upload an image; the background is removed automatically."
title = "移除背景 Demo"
description = "上傳圖片,自動去除背景."

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Permit faster, slightly lower-precision float32 matmul kernels (e.g. TF32)
# on hardware that supports them.
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation model. trust_remote_code=True lets
# transformers execute the model code shipped in the hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)
# from_pretrained already returns the model in eval mode; keep that explicit
# since this app only ever runs inference.
birefnet.eval()

# Spatial size (width, height) fed to the model.
# NOTE(review): upstream BiRefNet examples use 1024x1024; 256x256 is much
# cheaper (especially on CPU) but gives coarser masks — confirm intended.
MODEL_INPUT_SIZE = (256, 256)

# Resize, convert to a float tensor in [0, 1], then normalize with the
# ImageNet per-channel mean/std (expects exactly 3 channels).
transform_image = transforms.Compose([
    transforms.Resize(MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def load_img(image_path_or_url, timeout=30):
    """Load an image from a local path or an http(s) URL and return it as RGB.

    Args:
        image_path_or_url: A filesystem path, or a URL starting with
            ``http://`` or ``https://``.
        timeout: Seconds to wait for the remote server before giving up.
            Only used for URLs.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode, fully loaded into memory.

    Raises:
        requests.HTTPError: If the URL responds with an error status code.
        requests.Timeout: If the server does not respond within *timeout*.
        PIL.UnidentifiedImageError: If the data is not a readable image.
    """
    # Match real URL schemes only, so a local file named e.g. "http_logo.png"
    # is opened from disk instead of being fetched.
    if image_path_or_url.startswith(("http://", "https://")):
        # Without a timeout, requests can block forever on an unresponsive host.
        response = requests.get(image_path_or_url, timeout=timeout)
        # Surface HTTP errors clearly instead of handing an error page to PIL.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return img.convert("RGB")
    # convert() always returns a new, loaded image, so the file can be closed.
    with Image.open(image_path_or_url) as img:
        return img.convert("RGB")
def process(image):
    """Remove the background of *image* with BiRefNet.

    Args:
        image: A ``PIL.Image.Image`` in any mode, or ``None``. Gradio passes
            ``None`` when the user submits without an image.

    Returns:
        An ``"RGBA"`` image the same size as *image*, with the predicted
        foreground mask as its alpha channel. Returns ``None`` when *image*
        is ``None``.
    """
    if image is None:
        return None
    # The normalization step expects exactly 3 channels; RGBA, L or P inputs
    # would otherwise fail inside transform_image.
    rgb = image.convert("RGB")
    input_images = transform_image(rgb).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        # Use the model's last output; sigmoid maps it into (0, 1)
        # (presumably those outputs are logits — confirm against BiRefNet).
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Turn the mask into a PIL image and scale it back up to the input size.
    mask = transforms.ToPILImage()(pred).resize(rgb.size)
    # Keep the original pixels and use the mask as transparency.
    transparent_image = rgb.convert("RGBA")
    transparent_image.putalpha(mask)
    return transparent_image
def remove_background_gradio(image):
    """Gradio callback that returns *image* with its background removed.

    Delegates straight to ``process`` and hands back its result unchanged.
    """
    return process(image)
# Build the Gradio UI: a single PIL image in, a single PIL image out,
# plus a row of example inputs.
demo = gr.Interface(
    fn=remove_background_gradio,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    # NOTE(review): these example files must be present next to this script.
    examples=[
        ["girl1.png"],
        ["girl2.png"],
        ["girl3.png"],
        ["gonfu1.jpg"],
        ["angel.png"],
        ["statue.png"],
    ],
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module has no side
    # effects. share=True asks Gradio for a public link; per the Gradio docs,
    # share links are not supported when hosted on Hugging Face Spaces.
    demo.launch(share=True)