# image-segment / app.py
# Last commit f0ce5fe by Invictus-Jai: "Add Application Files"
import logging
import re
import zipfile
from io import BytesIO
from os import makedirs, path
from typing import Optional, Tuple
from urllib.parse import quote

import uvicorn
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image

from utils.image_segmenter import ImageSegmenter
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Image Background Remover",
    description="API for removing image backgrounds using ML",
    version="1.0.0",
)

# Define allowed origins explicitly for security. Port 5500 is presumably a
# local static dev server hosting the front end -- extend this list for any
# deployed front-end origin.
origins = [
    "http://127.0.0.1:5500",
    "http://localhost:5500",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # Use the defined origins instead of "*"
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # Specify only needed methods
    allow_headers=["*"],
)

# Resolve asset directories relative to this file rather than the process's
# current working directory. StaticFiles verifies its directory exists when it
# is constructed, so a bare relative "static" made the app fail at import time
# unless uvicorn happened to be launched from this folder.
_BASE_DIR = path.dirname(path.abspath(__file__))

# Set up the templates and static assets
templates = Jinja2Templates(directory=path.join(_BASE_DIR, "templates"))
app.mount(
    "/static",
    StaticFiles(directory=path.join(_BASE_DIR, "static")),
    name="static",
)

# Initialize ImageSegmenter once at startup; request handlers reuse this
# module-level instance instead of constructing one per request.
segmenter = ImageSegmenter()
@app.get("/")
async def index(request: Request) -> HTMLResponse:
    """Serve the front-end page by rendering the ``main.html`` template.

    The incoming request is placed in the template context under the
    ``"request"`` key, as the ``TemplateResponse(name, context)`` call form
    expects.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("main.html", page_context)
# Characters allowed verbatim in the ASCII `filename=` fallback of the
# Content-Disposition header; everything else is replaced with "_".
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._ -]")


async def _read_upload_image(file_obj: UploadFile) -> Image.Image:
    """Read an upload completely and decode it with Pillow.

    Raises:
        HTTPException: 400 "Invalid image file" if the bytes cannot be read
            or decoded as an image.
    """
    try:
        image_content = await file_obj.read()
        image = Image.open(BytesIO(image_content))
        # Image.open only parses the header; force a full decode now so a
        # truncated/corrupt upload is reported as a 400 here instead of
        # failing later inside the segmenter as a 500.
        image.load()
    except Exception as exc:
        logger.exception("Error reading image: %s", exc)
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    return image


def _build_result_zip(image, mask) -> BytesIO:
    """Pack the processed image and its mask into an in-memory ZIP archive.

    Entries are optimized PNGs named ``processed_image.png`` and ``mask.png``.
    ``image`` and ``mask`` are presumably Pillow images (they must support
    ``save(fp, "PNG", optimize=True)``) -- confirm against
    ``ImageSegmenter.segment``.

    Returns:
        The archive buffer, rewound to offset 0 so it can be streamed.
    """
    zip_buffer = BytesIO()
    entries = (("processed_image.png", image), ("mask.png", mask))
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
        for arcname, picture in entries:
            png_buffer = BytesIO()
            picture.save(png_buffer, "PNG", optimize=True)
            # getvalue() returns the whole buffer regardless of the current
            # position, so no seek(0) is needed before it.
            zip_file.writestr(arcname, png_buffer.getvalue())
    zip_buffer.seek(0)
    return zip_buffer


def _attachment_header(filename: Optional[str]) -> str:
    """Build the ``Content-Disposition`` value for the downloadable ZIP.

    The download keeps the ``result_<upload name>.zip`` naming, but the upload
    name is client-controlled: it may be missing, include a path, or contain
    quotes, CR/LF or non-ASCII characters. Raw, those could break or inject
    into the header, and Starlette encodes header values as Latin-1, so other
    characters would raise. The name is therefore reduced to its last path
    component and emitted twice (RFC 6266): an ASCII-only quoted ``filename``
    fallback and a percent-encoded UTF-8 ``filename*``.
    """
    base_name = path.basename((filename or "").replace("\\", "/")) or "image"
    download_name = f"result_{base_name}.zip"
    ascii_name = _UNSAFE_FILENAME_CHARS.sub("_", download_name)
    encoded_name = quote(download_name, safe="")
    return f"attachment; filename=\"{ascii_name}\"; filename*=UTF-8''{encoded_name}"


@app.post("/api/remove-background/")
async def remove_background(file_obj: UploadFile = File(...)) -> StreamingResponse:
    """Remove the background from an uploaded image.

    The upload is decoded with Pillow and passed to the shared ``segmenter``.
    The processed image and the mask it returns are sent back together as a
    ZIP archive (``processed_image.png`` + ``mask.png``).

    Args:
        file_obj: Multipart upload whose declared content type must start
            with ``image/`` (case-insensitive).

    Returns:
        A ``StreamingResponse`` with ``application/zip`` content and an
        attachment ``Content-Disposition`` header.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded; 500 if segmentation or building the response fails.
    """
    try:
        # content_type is None when the client omits it; MIME types are
        # case-insensitive, so normalize before checking the prefix.
        content_type = (file_obj.content_type or "").lower()
        if not content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        image = await _read_upload_image(file_obj)

        try:
            image, mask = await segmenter.segment(image)
        except Exception as exc:
            logger.exception("Error processing image: %s", exc)
            raise HTTPException(status_code=500, detail="Error processing image") from exc

        zip_buffer = _build_result_zip(image, mask)
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={"Content-Disposition": _attachment_header(file_obj.filename)},
        )
    except HTTPException:
        # Already mapped to the intended status code. Without this, the
        # catch-all below turned every 400 into a 500 "Internal server error".
        raise
    except Exception as exc:
        logger.exception("Unexpected error: %s", exc)
        raise HTTPException(status_code=500, detail="Internal server error") from exc
# if __name__ == "__main__":
#     uvicorn.run(
#         "app:app",  # reload=True requires an import string, not the app object
#         host="127.0.0.1",
#         port=8000,
#         log_level="info",
#         reload=True,  # Enable auto-reload during development
#     )