kharismagp commited on
Commit
273eb2f
·
verified ·
1 Parent(s): f4481e9

Upload 4 files

Browse files
Files changed (4) hide show
  1. example.jpeg +0 -0
  2. for_gradio.py +74 -0
  3. requirements.txt +9 -0
  4. runtime.txt +1 -0
example.jpeg ADDED
for_gradio.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision import transforms
5
+ from transformers import AutoModelForImageSegmentation
6
+
7
+ # Setup constants
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # Define image transformation pipeline
11
+ transform_image = transforms.Compose([
12
+ transforms.Resize((1024, 1024)),
13
+ transforms.ToTensor(),
14
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
15
+ ])
16
+
17
+ # Load the model ONCE globally
18
+ try:
19
+ torch.set_float32_matmul_precision("high")
20
+ model = AutoModelForImageSegmentation.from_pretrained(
21
+ "ZhengPeng7/BiRefNet_lite",
22
+ trust_remote_code=True
23
+ ).to(DEVICE)
24
+ print("Model loaded successfully.")
25
+ except Exception as e:
26
+ print(f"Error loading model: {str(e)}")
27
+ model = None
28
+
29
+ def process_image(image):
30
+ """Process a single image and remove its background"""
31
+ image = image.convert("RGB")
32
+ original_size = image.size
33
+ input_tensor = transform_image(image).unsqueeze(0).to(DEVICE)
34
+
35
+ with torch.no_grad():
36
+ preds = model(input_tensor)[-1].sigmoid().cpu()
37
+ pred = preds[0].squeeze()
38
+ mask = transforms.ToPILImage()(pred).resize(original_size)
39
+
40
+ result = image.copy()
41
+ result.putalpha(mask)
42
+
43
+ return result
44
+
45
+ def predict(image):
46
+ """Gradio interface function"""
47
+ if model is None:
48
+ raise gr.Error("Model not loaded. Check server logs.")
49
+ if image is None:
50
+ return None, None # Return None for both image and file
51
+
52
+ try:
53
+ result_image = process_image(image)
54
+ file_path = "processed_image.png"
55
+ result_image.save(file_path, "PNG")
56
+ return result_image, file_path
57
+
58
+ except Exception as e:
59
+ raise gr.Error(f"Error processing image: {e}")
60
+
61
+ # Gradio interface
62
+ interface = gr.Interface(
63
+ fn=predict,
64
+ inputs=gr.Image(type="pil"),
65
+ outputs=[
66
+ gr.Image(type="pil", label="Processed Image"),
67
+ gr.File(label="Download Processed Image")
68
+ ],
69
+ examples=[['example.jpeg']],
70
+ title="Background Removal App",
71
+ description="Upload an image to remove its background and download the processed image as a PNG."
72
+ )
73
+
74
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # loadimg
2
+ torch
3
+ torchvision
4
+ transformers
5
+ kornia
6
+ einops
7
+ timm
8
+ streamlit
9
+ gradio
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python_version=3.10