"""FastAPI application exposing an image background-removal endpoint.

The app serves an HTML front page at ``/`` and a single API endpoint,
``POST /api/remove-background/``, which returns a ZIP archive containing the
processed image and its mask as PNG files.
"""

import logging
import re
import zipfile
from io import BytesIO
from os import makedirs, path
from typing import Optional, Tuple

import uvicorn
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image

from utils.image_segmenter import ImageSegmenter

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Image Background Remover",
    description="API for removing image backgrounds using ML",
    version="1.0.0",
)

# Define allowed origins explicitly for security
origins = [
    "http://127.0.0.1:5500",
    "http://localhost:5500",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # Use the defined origins instead of "*"
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # Specify only needed methods
    allow_headers=["*"],
)

# Set up the templates and static assets
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize ImageSegmenter once at import time so every request reuses it
segmenter = ImageSegmenter()

# Anything outside this conservative set is replaced before the client-supplied
# filename is placed in a response header.
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]+")


def _safe_download_name(filename: Optional[str]) -> str:
    """Return a header-safe ``result_<name>.zip`` download name.

    The upload's filename is client-controlled: it may be missing, contain
    quotes, path separators or control characters, or characters that cannot
    be encoded in an HTTP header. Runs of such characters are collapsed to
    ``_``; if nothing usable remains, ``image`` is used instead.
    """
    cleaned = _UNSAFE_FILENAME_CHARS.sub("_", filename or "").strip("._")
    return f"result_{cleaned or 'image'}.zip"


def _png_bytes(img) -> bytes:
    """Encode *img* as an optimized PNG and return the raw bytes.

    *img* must expose a PIL-style ``save(fp, format, optimize=...)`` method.
    """
    buffer = BytesIO()
    img.save(buffer, "PNG", optimize=True)
    return buffer.getvalue()


def _build_result_archive(image, mask) -> BytesIO:
    """Pack *image* and *mask* as PNGs into an in-memory ZIP archive.

    Returns the archive buffer rewound to position 0, ready to be streamed.
    """
    archive = BytesIO()
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zip_file:
        zip_file.writestr("processed_image.png", _png_bytes(image))
        zip_file.writestr("mask.png", _png_bytes(mask))
    archive.seek(0)
    return archive


@app.get("/")
async def index(request: Request) -> HTMLResponse:
    """Render the ``main.html`` template as the front page."""
    return templates.TemplateResponse("main.html", {"request": request})


@app.post("/api/remove-background/")
async def remove_background(file_obj: UploadFile = File(...)) -> StreamingResponse:
    """Remove the background from an uploaded image.

    Args:
        file_obj: The uploaded file; its content type must start with
            ``image/``.

    Returns:
        A streamed ZIP archive (``application/zip``) containing
        ``processed_image.png`` and ``mask.png``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be opened as one; 500 if segmentation fails or any other
            unexpected error occurs.
    """
    try:
        # Validate file type. The content type can be absent, so treat a
        # missing value as "not an image" instead of crashing on None.
        content_type = file_obj.content_type or ""
        if not content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read image with error handling
        try:
            image_content = await file_obj.read()
            image = Image.open(BytesIO(image_content))
        except Exception as exc:
            logger.error("Error reading image: %s", exc)
            raise HTTPException(status_code=400, detail="Invalid image file") from exc

        # Process image. NOTE(review): segment() is awaited, so it is expected
        # to be a coroutine returning an (image, mask) pair - confirm against
        # utils.image_segmenter.
        try:
            image, mask = await segmenter.segment(image)
        except Exception as exc:
            logger.exception("Error processing image: %s", exc)
            raise HTTPException(status_code=500, detail="Error processing image") from exc

        # Create ZIP file in memory
        zip_buffer = _build_result_archive(image, mask)
        download_name = _safe_download_name(file_obj.filename)

        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={
                "Content-Disposition": f'attachment; filename="{download_name}"'
            },
        )
    except HTTPException:
        # Deliberate HTTP errors raised above (400/500) must reach the client
        # unchanged; without this clause the catch-all below would turn every
        # one of them into a generic 500.
        raise
    except Exception as exc:
        logger.exception("Unexpected error: %s", exc)
        raise HTTPException(status_code=500, detail="Internal server error") from exc


# Development entry point, intentionally left disabled:
# if __name__ == "__main__":
#     uvicorn.run(
#         app,
#         host="127.0.0.1",
#         port=8000,
#         log_level="info",
#         reload=True  # Enable auto-reload during development
#     )