import gradio as gr
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Setup constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Define image transformation pipeline (BiRefNet expects 1024x1024, ImageNet-normalized input)
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the model ONCE globally
try:
    torch.set_float32_matmul_precision("high")
    model = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet_lite", trust_remote_code=True
    ).to(DEVICE)
    model.eval()  # inference mode: disables dropout and uses running norm statistics
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    model = None


def process_image(image):
    """Process a single image and remove its background."""
    image = image.convert("RGB")
    original_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Turn the predicted foreground probability map into an alpha mask at the original resolution
    mask = transforms.ToPILImage()(pred).resize(original_size)
    result = image.copy()
    result.putalpha(mask)
    return result


def predict(image):
    """Gradio interface function."""
    if model is None:
        raise gr.Error("Model not loaded. Check server logs.")
    if image is None:
        return None, None  # Return None for both the image and the file
    try:
        result_image = process_image(image)
        file_path = "processed_image.png"
        result_image.save(file_path, "PNG")
        return result_image, file_path
    except Exception as e:
        raise gr.Error(f"Error processing image: {e}")


# Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Processed Image"),
        gr.File(label="Download Processed Image")
    ],
    examples=[['example.jpeg']],
    title="Background Removal App",
    description="Upload an image to remove its background and download the processed image as a PNG."
)

interface.launch()