File size: 3,623 Bytes
f0ce5fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from fastapi import FastAPI, Request, UploadFile, File, HTTPException
import uvicorn
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from os import path, makedirs
from PIL import Image
import logging
from io import BytesIO
from typing import Tuple
from utils.image_segmenter import ImageSegmenter
import zipfile

# Set up logging
# Root logger is configured at INFO; `logger` is the module-level logger
# used by the request handlers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object; title/description/version are the
# metadata passed to the FastAPI constructor.
app = FastAPI(title="Image Background Remover",
             description="API for removing image backgrounds using ML",
             version="1.0.0")

# Define allowed origins explicitly for security
# NOTE(review): both entries are local addresses on port 5500 — presumably a
# development front end; confirm and extend before deploying elsewhere.
origins = [
    "http://127.0.0.1:5500",
    "http://localhost:5500"
]

# CORS policy: only the origins listed above, credentials allowed,
# GET and POST only, any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # Use the defined origins instead of "*"
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # Specify only needed methods
    allow_headers=["*"],
)

# Set up the templates
# Templates are loaded from ./templates; files under ./static are served
# at the /static URL prefix. Both paths are relative to the working directory.
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize ImageSegmenter once at startup
# The instance is created at import time and shared by every request.
segmenter = ImageSegmenter()

@app.get("/")
async def index(request: Request) -> HTMLResponse:
    """Serve the landing page.

    Renders the ``main.html`` template, passing the incoming request
    through in the template context under the ``"request"`` key.
    """
    context = {"request": request}
    page = templates.TemplateResponse("main.html", context)
    return page

def _safe_download_stem(filename) -> str:
    """Return a header-safe name derived from a client-supplied filename.

    Any directory part is dropped and every character other than ASCII
    letters, digits, ``.``, ``_`` and ``-`` is replaced with ``_``, so the
    result can be embedded in a quoted ``Content-Disposition`` value.
    Falls back to ``"image"`` when *filename* is ``None`` or empty.
    """
    # Handle both separator styles; the client controls this string.
    base = (filename or "").replace("\\", "/").rsplit("/", 1)[-1]
    cleaned = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "._-" else "_"
        for ch in base
    )
    return cleaned or "image"


def _encode_png(img) -> bytes:
    """Serialize a PIL image to optimized PNG bytes."""
    buffer = BytesIO()
    img.save(buffer, "PNG", optimize=True)
    return buffer.getvalue()


@app.post("/api/remove-background/")
async def remove_background(file_obj: UploadFile = File(...)) -> StreamingResponse:
    """Segment an uploaded image and return the results as a ZIP download.

    Args:
        file_obj: The uploaded file; its declared content type must start
            with ``image/``.

    Returns:
        A streaming ``application/zip`` response containing
        ``processed_image.png`` and ``mask.png``. The download is named
        ``result_<sanitized upload name>.zip``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or
            cannot be opened as one; 500 if segmentation fails or any other
            unexpected error occurs.
    """
    try:
        # Validate file type. The content type may be missing entirely, in
        # which case it must be rejected as a client error, not crash.
        content_type = file_obj.content_type or ""
        if not content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read image with error handling
        try:
            image_content = await file_obj.read()
            image = Image.open(BytesIO(image_content))
        except Exception as exc:
            logger.error("Error reading image: %s", exc)
            raise HTTPException(status_code=400, detail="Invalid image file") from exc

        # Process image
        try:
            image, mask = await segmenter.segment(image)
        except Exception as exc:
            logger.error("Error processing image: %s", exc)
            raise HTTPException(status_code=500, detail="Error processing image") from exc

        # Create ZIP file in memory
        zip_buffer = BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            zip_file.writestr('processed_image.png', _encode_png(image))
            zip_file.writestr('mask.png', _encode_png(mask))

        # Rewind so the response streams the archive from its first byte.
        zip_buffer.seek(0)

        # The upload name is untrusted: sanitize and quote it so it cannot
        # break or inject into the response header.
        download_name = f"result_{_safe_download_stem(file_obj.filename)}.zip"
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={
                "Content-Disposition": f'attachment; filename="{download_name}"'
            }
        )
    except HTTPException:
        # Deliberate HTTP errors raised above (400/500) must reach the client
        # unchanged instead of being rewritten as a generic 500 below.
        raise
    except Exception as exc:
        logger.exception("Unexpected error: %s", exc)
        raise HTTPException(status_code=500, detail="Internal server error") from exc

# if __name__ == "__main__":
#     uvicorn.run(
#         app, 
#         host="127.0.0.1", 
#         port=8000,
#         log_level="info",
#         reload=True  # Enable auto-reload during development
#     )