Spaces:
Running
Running
from fastapi import FastAPI, Request, UploadFile, File, HTTPException | |
import uvicorn | |
from fastapi.responses import HTMLResponse, StreamingResponse | |
from fastapi.templating import Jinja2Templates | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.middleware.cors import CORSMiddleware | |
from os import path, makedirs | |
from PIL import Image | |
import logging | |
from io import BytesIO | |
from typing import Tuple | |
from utils.image_segmenter import ImageSegmenter | |
import zipfile | |
# Logging: configure the root logger at INFO once, at import time; this module
# writes through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application plus the metadata shown in the generated OpenAPI docs.
app = FastAPI(
    title="Image Background Remover",
    description="API for removing image backgrounds using ML",
    version="1.0.0",
)

# Cross-origin access: an explicit allow-list of browser origins (rather than
# the "*" wildcard) for security.
origins = [
    "http://127.0.0.1:5500",
    "http://localhost:5500",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # only the verbs this API serves
    allow_headers=["*"],
)

# Jinja2 templates and static assets; both directory paths are relative to the
# process's current working directory.
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# One segmenter, built once when the module is imported and then reused by
# every request handler that references `segmenter`.
segmenter = ImageSegmenter()
@app.get("/", response_class=HTMLResponse)
async def index(request: Request) -> HTMLResponse:
    """Render the ``main.html`` template as the landing page.

    NOTE(review): registered at ``/``; the path is assumed because the original
    handler had no route decorator. Confirm it matches what the frontend expects.

    Args:
        request: The incoming request, passed into the template context as
            Jinja2Templates requires.

    Returns:
        The rendered ``main.html`` page.
    """
    return templates.TemplateResponse("main.html", {"request": request})
def _png_bytes(img) -> bytes:
    """Encode a PIL image as an optimized PNG and return the raw bytes."""
    buffer = BytesIO()
    img.save(buffer, "PNG", optimize=True)
    return buffer.getvalue()


def _safe_download_name(filename) -> str:
    """Reduce a client-supplied filename to characters safe in a header value.

    Every character other than ASCII letters, digits, ``.``, ``_`` and ``-`` is
    replaced by ``_``. This blocks quote/semicolon injection into
    ``Content-Disposition`` and avoids characters that cannot be encoded as
    latin-1 when the response headers are built. Returns ``"image"`` when the
    client sent no filename (``None`` or empty).
    """
    if not filename:
        return "image"
    return "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "._-") else "_"
        for ch in filename
    )


async def remove_background(file_obj: UploadFile = File(...)) -> StreamingResponse:
    """Strip the background from an uploaded image and return the results as a ZIP.

    NOTE(review): no route decorator is visible on this handler, so as shown
    it is not registered with ``app``. Confirm the path the frontend calls and
    register it (e.g. with ``@app.post(...)``).

    Args:
        file_obj: The uploaded file; its declared content type must start
            with ``image/``.

    Returns:
        A ``StreamingResponse`` whose ``application/zip`` body holds
        ``processed_image.png`` and ``mask.png``, the two images returned by
        ``segmenter.segment``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded; 500 if segmentation fails or on any other error.
    """
    try:
        # content_type is None when the client omits the header.
        content_type = file_obj.content_type or ""
        if not content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        try:
            image_content = await file_obj.read()
            image = Image.open(BytesIO(image_content))
            # Image.open() only reads the header; decode the pixel data now so
            # corrupt or truncated uploads are rejected here as a 400 instead of
            # failing later inside segmentation as a 500.
            image.load()
        except Exception as e:
            logger.exception("Error reading image: %s", e)
            raise HTTPException(status_code=400, detail="Invalid image file") from e

        try:
            image, mask = await segmenter.segment(image)
        except Exception as e:
            logger.exception("Error processing image: %s", e)
            raise HTTPException(status_code=500, detail="Error processing image") from e

        # Package both outputs into an in-memory ZIP archive.
        zip_buffer = BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            zip_file.writestr("processed_image.png", _png_bytes(image))
            zip_file.writestr("mask.png", _png_bytes(mask))
        zip_buffer.seek(0)

        download_name = _safe_download_name(file_obj.filename)
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={
                "Content-Disposition": f'attachment; filename="result_{download_name}.zip"'
            },
        )
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client with their
        # own status and detail; without this clause the catch-all below would
        # turn every one of them into a generic 500.
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
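# Development entry point (kept commented out). NOTE: uvicorn.run() refuses to
# start when reload=True is combined with an app object; to use auto-reload,
# pass an import string instead, e.g. "app:app" if this file is app.py (adjust
# to the real module name).
# if __name__ == "__main__":
#     uvicorn.run(
#         app,
#         host="127.0.0.1",
#         port=8000,
#         log_level="info",
#         reload=True  # Enable auto-reload during development
#     )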