Invictus-Jai commited on
Commit
f0ce5fe
·
1 Parent(s): 7ce0457

Add Application Files

Browse files
Files changed (7) hide show
  1. Dockerfile +13 -0
  2. app.py +105 -0
  3. requirements.txt +11 -0
  4. static/script.js +165 -0
  5. static/style.css +201 -0
  6. templates/main.html +37 -0
  7. utils/image_segmenter.py +56 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11.0
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+ ENV PATH="/home/user/.local/bin:$PATH"
6
+
7
+ WORKDIR /app
8
+
9
+ COPY --chown=user ./requirements.txt requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ COPY --chown=user . /app
13
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, UploadFile, File, HTTPException
2
+ import uvicorn
3
+ from fastapi.responses import HTMLResponse, StreamingResponse
4
+ from fastapi.templating import Jinja2Templates
5
+ from fastapi.staticfiles import StaticFiles
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from os import path, makedirs
8
+ from PIL import Image
9
+ import logging
10
+ from io import BytesIO
11
+ from typing import Tuple
12
+ from utils.image_segmenter import ImageSegmenter
13
+ import zipfile
14
+
15
+ # Set up logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ app = FastAPI(title="Image Background Remover",
20
+ description="API for removing image backgrounds using ML",
21
+ version="1.0.0")
22
+
23
+ # Define allowed origins explicitly for security
24
+ origins = [
25
+ "http://127.0.0.1:5500",
26
+ "http://localhost:5500"
27
+ ]
28
+
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=origins, # Use the defined origins instead of "*"
32
+ allow_credentials=True,
33
+ allow_methods=["GET", "POST"], # Specify only needed methods
34
+ allow_headers=["*"],
35
+ )
36
+
37
+ # Set up the templates
38
+ templates = Jinja2Templates(directory="templates")
39
+ app.mount("/static", StaticFiles(directory="static"), name="static")
40
+
41
+ # Initialize ImageSegmenter once at startup
42
+ segmenter = ImageSegmenter()
43
+
44
+ @app.get("/")
45
+ async def index(request: Request) -> HTMLResponse:
46
+ return templates.TemplateResponse("main.html", {"request": request})
47
+
48
+ @app.post("/api/remove-background/")
49
+ async def remove_background(file_obj: UploadFile = File(...)) -> StreamingResponse:
50
+ try:
51
+ # Validate file type
52
+ if not file_obj.content_type.startswith('image/'):
53
+ raise HTTPException(status_code=400, detail="File must be an image")
54
+
55
+ # Read image with error handling
56
+ try:
57
+ image_content = await file_obj.read()
58
+ image = Image.open(BytesIO(image_content))
59
+ except Exception as e:
60
+ logger.error(f"Error reading image: {str(e)}")
61
+ raise HTTPException(status_code=400, detail="Invalid image file")
62
+
63
+ # Process image
64
+ try:
65
+ image, mask = await segmenter.segment(image) # Fixed typo in method name
66
+ except Exception as e:
67
+ logger.error(f"Error processing image: {str(e)}")
68
+ raise HTTPException(status_code=500, detail="Error processing image")
69
+
70
+ # Create ZIP file in memory
71
+ zip_buffer = BytesIO()
72
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
73
+ # Save processed image
74
+ image_buffer = BytesIO()
75
+ image.save(image_buffer, "PNG", optimize=True)
76
+ image_buffer.seek(0)
77
+ zip_file.writestr('processed_image.png', image_buffer.getvalue())
78
+
79
+ # Save mask
80
+ mask_buffer = BytesIO()
81
+ mask.save(mask_buffer, "PNG", optimize=True)
82
+ mask_buffer.seek(0)
83
+ zip_file.writestr('mask.png', mask_buffer.getvalue())
84
+
85
+ zip_buffer.seek(0)
86
+
87
+ return StreamingResponse(
88
+ zip_buffer,
89
+ media_type="application/zip",
90
+ headers={
91
+ "Content-Disposition": f"attachment; filename=result_{file_obj.filename}.zip"
92
+ }
93
+ )
94
+ except Exception as e:
95
+ logger.error(f"Unexpected error: {str(e)}")
96
+ raise HTTPException(status_code=500, detail="Internal server error")
97
+
98
+ # if __name__ == "__main__":
99
+ # uvicorn.run(
100
+ # app,
101
+ # host="127.0.0.1",
102
+ # port=8000,
103
+ # log_level="info",
104
+ # reload=True # Enable auto-reload during development
105
+ # )
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pillow
4
+ torch
5
+ torchvision
6
+ transformers
7
+ python-multipart
8
+ jinja2
9
+ aiofiles
10
+ kornia
11
+ pyngrok
static/script.js ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const uploadArea = document.getElementById("uploadArea");
2
+ const fileInput = document.getElementById("fileInput");
3
+ const previewImage = document.getElementById("previewImage");
4
+ const resultImage = document.getElementById("resultImage");
5
+ const removeBackgroundBtn = document.getElementById("removeBackgroundBtn");
6
+ const uploadLoading = document.getElementById("uploadLoading");
7
+ const resultLoading = document.getElementById("resultLoading");
8
+
9
+ // Add these new elements to your HTML
10
+ const imageContainer = document.querySelector(".image-box:nth-child(2)");
11
+ const toggleContainer = document.createElement("div");
12
+ toggleContainer.className = "toggle-container";
13
+ toggleContainer.innerHTML = `
14
+ <select id="viewSelector" class="view-selector">
15
+ <option value="processed">Processed Image</option>
16
+ <option value="mask">Mask</option>
17
+ </select>
18
+ `;
19
+ imageContainer.insertBefore(
20
+ toggleContainer,
21
+ imageContainer.querySelector(".upload-area")
22
+ );
23
+
24
+ // Store the extracted images
25
+ let processedImageUrl = null;
26
+ let maskImageUrl = null;
27
+
28
+ uploadArea.addEventListener("click", () => {
29
+ fileInput.click();
30
+ });
31
+
32
+ uploadArea.addEventListener("dragover", (e) => {
33
+ e.preventDefault();
34
+ uploadArea.style.borderColor = "#666666";
35
+ uploadArea.style.backgroundColor = "#333333";
36
+ });
37
+
38
+ uploadArea.addEventListener("dragleave", (e) => {
39
+ e.preventDefault();
40
+ uploadArea.style.borderColor = "#444444";
41
+ uploadArea.style.backgroundColor = "#222222";
42
+ });
43
+
44
+ uploadArea.addEventListener("drop", (e) => {
45
+ e.preventDefault();
46
+ uploadArea.style.borderColor = "#444444";
47
+ uploadArea.style.backgroundColor = "#222222";
48
+
49
+ const file = e.dataTransfer.files[0];
50
+ if (file && file.type.startsWith("image/")) {
51
+ handleImageUpload(file);
52
+ }
53
+ });
54
+
55
+ fileInput.addEventListener("change", (e) => {
56
+ const file = e.target.files[0];
57
+ if (file) {
58
+ handleImageUpload(file);
59
+ }
60
+ });
61
+
62
+ function handleImageUpload(file) {
63
+ const reader = new FileReader();
64
+
65
+ uploadLoading.style.display = "block";
66
+
67
+ reader.onload = (e) => {
68
+ previewImage.src = e.target.result;
69
+ previewImage.style.display = "block";
70
+ uploadArea.querySelector("p").style.display = "none";
71
+ removeBackgroundBtn.disabled = false;
72
+ uploadLoading.style.display = "none";
73
+ };
74
+
75
+ reader.readAsDataURL(file);
76
+ }
77
+
78
+ // Function to extract images from ZIP
79
+ async function extractImagesFromZip(blob) {
80
+ const zip = await JSZip.loadAsync(blob);
81
+ const files = Object.values(zip.files);
82
+
83
+ const processedImageFile = files.find(
84
+ (file) => file.name === "processed_image.png"
85
+ );
86
+ const maskFile = files.find((file) => file.name === "mask.png");
87
+
88
+ if (!processedImageFile || !maskFile) {
89
+ throw new Error("Missing required images in ZIP file");
90
+ }
91
+
92
+ // Get blobs for both images
93
+ const processedImageBlob = await processedImageFile.async("blob");
94
+ const maskBlob = await maskFile.async("blob");
95
+
96
+ // Create URLs for the images
97
+ if (processedImageUrl) URL.revokeObjectURL(processedImageUrl);
98
+ if (maskImageUrl) URL.revokeObjectURL(maskImageUrl);
99
+
100
+ processedImageUrl = URL.createObjectURL(processedImageBlob);
101
+ maskImageUrl = URL.createObjectURL(maskBlob);
102
+
103
+ return { processedImageUrl, maskImageUrl };
104
+ }
105
+
106
+ // Add event listener for view selector
107
+ document.getElementById("viewSelector").addEventListener("change", (e) => {
108
+ resultImage.src =
109
+ e.target.value === "processed" ? processedImageUrl : maskImageUrl;
110
+ });
111
+
112
+ removeBackgroundBtn.addEventListener("click", async () => {
113
+ try {
114
+ resultLoading.style.display = "block";
115
+ removeBackgroundBtn.disabled = true;
116
+ document.getElementById("viewSelector").disabled = true;
117
+
118
+ // Convert base64 to blob
119
+ const image = fileInput.files[0];
120
+
121
+ // Create FormData
122
+ const formData = new FormData();
123
+ formData.append("file_obj", image);
124
+
125
+ // Send to backend
126
+ const response = await fetch("/api/remove-background/", {
127
+ method: "POST",
128
+ body: formData,
129
+ });
130
+
131
+ if (!response.ok) {
132
+ throw new Error("Failed to remove background");
133
+ }
134
+
135
+ const blob = await response.blob();
136
+
137
+ // Extract and display images from ZIP
138
+ const { processedImageUrl: processedUrl, maskImageUrl: maskUrl } =
139
+ await extractImagesFromZip(blob);
140
+
141
+ // Show processed image by default
142
+ resultImage.src = processedUrl;
143
+ resultImage.style.display = "block";
144
+ document.getElementById("viewSelector").disabled = false;
145
+
146
+ // Create download link
147
+ const downloadLink = document.createElement("a");
148
+ downloadLink.href = URL.createObjectURL(blob);
149
+ downloadLink.download = "images.zip";
150
+ downloadLink.className = "download-link";
151
+ downloadLink.innerText = "Download ZIP";
152
+
153
+ // Add download link below the image container
154
+ const downloadContainer = document.createElement("div");
155
+ downloadContainer.className = "download-container";
156
+ downloadContainer.appendChild(downloadLink);
157
+ imageContainer.appendChild(downloadContainer);
158
+ } catch (error) {
159
+ console.error("Error:", error);
160
+ alert("Failed to process image. Please try again.");
161
+ } finally {
162
+ resultLoading.style.display = "none";
163
+ removeBackgroundBtn.disabled = false;
164
+ }
165
+ });
static/style.css ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ margin: 0;
3
+ padding: 0;
4
+ box-sizing: border-box;
5
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
6
+ }
7
+
8
+ body {
9
+ min-height: 100vh;
10
+ background: linear-gradient(135deg, #1a1a1a, #000000);
11
+ padding: 2rem;
12
+ color: #ffffff;
13
+ }
14
+
15
+ .container {
16
+ max-width: 1200px;
17
+ margin: 0 auto;
18
+ }
19
+
20
+ .title {
21
+ text-align: center;
22
+ color: #ffffff;
23
+ margin-bottom: 2rem;
24
+ font-size: 2rem;
25
+ text-shadow: 0 2px 4px rgba(0, 0, 0, 0.3);
26
+ }
27
+
28
+ .image-containers {
29
+ display: grid;
30
+ grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
31
+ gap: 2rem;
32
+ margin-bottom: 2rem;
33
+ }
34
+
35
+ .image-box {
36
+ background-color: #2a2a2a;
37
+ border-radius: 12px;
38
+ padding: 1rem;
39
+ box-shadow: 0 8px 16px rgba(0, 0, 0, 0.4);
40
+ height: 400px;
41
+ display: flex;
42
+ flex-direction: column;
43
+ position: relative;
44
+ border: 1px solid #333333;
45
+ }
46
+
47
+ .image-box h2 {
48
+ color: #ffffff;
49
+ margin-bottom: 1rem;
50
+ font-size: 1.2rem;
51
+ }
52
+
53
+ .upload-area {
54
+ flex: 1;
55
+ border: 2px dashed #444444;
56
+ border-radius: 8px;
57
+ display: flex;
58
+ align-items: center;
59
+ justify-content: center;
60
+ cursor: pointer;
61
+ transition: all 0.3s ease;
62
+ position: relative;
63
+ overflow: hidden;
64
+ background-color: #222222;
65
+ }
66
+
67
+ .upload-area:hover {
68
+ border-color: #666666;
69
+ background-color: #333333;
70
+ }
71
+
72
+ .upload-area p {
73
+ color: #999999;
74
+ text-align: center;
75
+ padding: 1rem;
76
+ }
77
+
78
+ .preview-image {
79
+ width: 100%;
80
+ height: 100%;
81
+ object-fit: contain;
82
+ display: none;
83
+ }
84
+
85
+ .result-image {
86
+ width: 100%;
87
+ height: 100%;
88
+ object-fit: contain;
89
+ }
90
+
91
+ .remove-bg-btn {
92
+ display: block;
93
+ width: 100%;
94
+ max-width: 300px;
95
+ margin: 0 auto;
96
+ padding: 1rem 2rem;
97
+ background-color: #ffffff;
98
+ color: #000000;
99
+ border: none;
100
+ border-radius: 8px;
101
+ cursor: pointer;
102
+ font-size: 1rem;
103
+ transition: all 0.3s ease;
104
+ font-weight: 500;
105
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
106
+ }
107
+
108
+ .remove-bg-btn:hover {
109
+ background-color: #f0f0f0;
110
+ transform: translateY(-2px);
111
+ box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3);
112
+ }
113
+
114
+ .remove-bg-btn:disabled {
115
+ background-color: #444444;
116
+ color: #666666;
117
+ cursor: not-allowed;
118
+ transform: none;
119
+ box-shadow: none;
120
+ }
121
+
122
+ .loading {
123
+ position: absolute;
124
+ top: 50%;
125
+ left: 50%;
126
+ transform: translate(-50%, -50%);
127
+ display: none;
128
+ }
129
+
130
+ .loading::after {
131
+ content: "";
132
+ width: 40px;
133
+ height: 40px;
134
+ border: 4px solid #333333;
135
+ border-top: 4px solid #ffffff;
136
+ border-radius: 50%;
137
+ animation: spin 1s linear infinite;
138
+ display: block;
139
+ }
140
+
141
+ @keyframes spin {
142
+ 0% {
143
+ transform: rotate(0deg);
144
+ }
145
+ 100% {
146
+ transform: rotate(360deg);
147
+ }
148
+ }
149
+
150
+ @media (max-width: 768px) {
151
+ body {
152
+ padding: 1rem;
153
+ }
154
+
155
+ .image-box {
156
+ height: 300px;
157
+ }
158
+
159
+ .title {
160
+ font-size: 1.5rem;
161
+ }
162
+ }
163
+
164
+ .toggle-container {
165
+ margin-bottom: 1rem;
166
+ text-align: center;
167
+ }
168
+
169
+ .view-selector {
170
+ padding: 0.5rem;
171
+ border-radius: 4px;
172
+ background-color: #333333;
173
+ color: #ffffff;
174
+ border: 1px solid #444444;
175
+ cursor: pointer;
176
+ width: 150px;
177
+ }
178
+
179
+ .view-selector:disabled {
180
+ opacity: 0.5;
181
+ cursor: not-allowed;
182
+ }
183
+
184
+ .download-container {
185
+ text-align: center;
186
+ margin-top: 1rem;
187
+ }
188
+
189
+ .download-link {
190
+ display: inline-block;
191
+ padding: 0.5rem 1rem;
192
+ background-color: #333333;
193
+ color: #ffffff;
194
+ text-decoration: none;
195
+ border-radius: 4px;
196
+ transition: all 0.3s ease;
197
+ }
198
+
199
+ .download-link:hover {
200
+ background-color: #444444;
201
+ }
templates/main.html ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Background Remover</title>
7
+ <link rel="stylesheet" href="/static/style.css">
8
+
9
+ </head>
10
+ <body>
11
+ <div class="container">
12
+ <h1 class="title">Background Remover</h1>
13
+ <div class="image-containers">
14
+ <div class="image-box">
15
+ <h2>Upload Image</h2>
16
+ <div class="upload-area" id="uploadArea">
17
+ <p>Click or drag image here</p>
18
+ <input type="file" id="fileInput" hidden accept="image/*">
19
+ <img id="previewImage" class="preview-image">
20
+ <div class="loading" id="uploadLoading"></div>
21
+ </div>
22
+ </div>
23
+ <div class="image-box">
24
+ <h2>Result</h2>
25
+ <div class="upload-area">
26
+ <img id="resultImage" class="result-image">
27
+ <div class="loading" id="resultLoading"></div>
28
+ </div>
29
+ </div>
30
+ </div>
31
+ <button class="remove-bg-btn" id="removeBackgroundBtn" disabled>Remove Background</button>
32
+ </div>
33
+
34
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jszip/3.10.1/jszip.min.js"></script>
35
+ <script src="/static/script.js"></script>
36
+ </body>
37
+ </html>
utils/image_segmenter.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ from torchvision import transforms # type: ignore
4
+ from transformers import AutoModelForImageSegmentation
5
+ from typing import Tuple
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class ImageSegmenter:
11
+ def __init__(self, image_size: Tuple[int, int] = (1024, 1024)):
12
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ logger.info(f"Using device: {self.device}")
14
+
15
+ try:
16
+ self.birefnet = AutoModelForImageSegmentation.from_pretrained(
17
+ "ZhengPeng7/BiRefNet",
18
+ trust_remote_code=True
19
+ )
20
+ self.birefnet.to(self.device)
21
+ self.birefnet.eval() # Set model to evaluation mode
22
+ except Exception as e:
23
+ logger.error(f"Error loading model: {str(e)}")
24
+ raise
25
+
26
+ self.transform_image = transforms.Compose([
27
+ transforms.Resize(image_size),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(
30
+ mean=[0.485, 0.456, 0.406],
31
+ std=[0.229, 0.224, 0.225]
32
+ ),
33
+ ])
34
+
35
+ async def extract_object(self, image: Image.Image) -> Tuple[Image.Image, Image.Image]:
36
+ try:
37
+ # Transform image
38
+ input_images = self.transform_image(image).unsqueeze(0).to(self.device)
39
+
40
+ # Prediction
41
+ with torch.no_grad():
42
+ preds = self.birefnet(input_images)[-1].sigmoid().cpu()
43
+
44
+ pred = preds[0].squeeze()
45
+ pred_pil = transforms.ToPILImage()(pred)
46
+ mask = pred_pil.resize(image.size)
47
+ image.putalpha(mask)
48
+
49
+ return image, mask
50
+ except Exception as e:
51
+ logger.error(f"Error in extract_object: {str(e)}")
52
+ raise
53
+
54
+ async def segment(self, image: Image.Image) -> Tuple[Image.Image, Image.Image]:
55
+ """Fixed typo in method name and added type hints"""
56
+ return await self.extract_object(image)