|
import os |
|
from fastapi import FastAPI, UploadFile |
|
from fastapi.responses import JSONResponse |
|
from fastapi.param_functions import File |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from typing import List |
|
import io |
|
from facenet_pytorch import MTCNN, InceptionResnetV1 |
|
import torch |
|
from PIL import Image |
|
|
|
# Single application instance; every route decorator in this module attaches to it.
app = FastAPI()

# Cross-origin policy: any origin, any method and any header is accepted, and
# credentialed requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is a risky
# pairing (Starlette may reflect the caller's Origin header so the browser
# accepts credentialed responses) — confirm this is intended for this service.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
|
|
# Redirect torch's cache directory to /tmp; presumably so the pretrained
# weights requested below are stored somewhere writable — confirm.
os.environ["TORCH_HOME"] = "/tmp/.cache/torch"

# Run models on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Face detector; keep_all=True keeps every detected face rather than a single one
# (per the facenet_pytorch MTCNN options).
mtcnn = MTCNN(keep_all=True, device=device)

# Embedding network loaded with 'vggface2' pretrained weights, placed on the
# selected device and switched to evaluation mode for inference.
resnet = InceptionResnetV1(pretrained="vggface2").to(device).eval()
|
|
|
|
|
@app.get("/", tags=["Home"])
def read_root():
    """Return a static greeting payload for the API root."""
    greeting = "Welcome to the face embeddings API!"
    return {"message": greeting}
|
|
|
@app.get("/health", tags=["Health"])
def health_check():
    """Report a constant "ok" status; the models are not exercised here."""
    payload = {"status": "ok"}
    return payload
|
|
|
@app.post("/extract", tags=["Extract Embeddings"])
async def extract_embeddings(file: UploadFile = File(...)):
    """Detect faces in an uploaded image and return one embedding per face.

    The upload is decoded with Pillow and converted to RGB. It is then passed
    through the module-level ``mtcnn`` detector, and the resulting face
    tensor(s) are fed through ``resnet`` under ``torch.no_grad()``.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A JSONResponse with one of these bodies:

        - ``{"embeddings": [[float, ...], ...]}`` on success, with one inner
          list per face.
        - ``{"error": ...}`` with status 400 when the upload cannot be decoded
          as an image.
        - ``{"error": ...}`` with status 200 when no face is detected. This
          status is kept as before for compatibility with existing clients.
    """
    contents = await file.read()

    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # The upload may be empty, corrupt, truncated, not an image, or
        # oversized. Pillow signals these with UnidentifiedImageError (an
        # OSError subclass), OSError on load, or DecompressionBombError.
        # Report a client error instead of letting it escape as a 500.
        return JSONResponse(
            status_code=400,
            content={"error": "Uploaded file is not a valid image"},
        )

    # NOTE(review): detection and embedding run synchronously inside this async
    # handler. They block the event loop, and so every other request including
    # /health, for the duration of inference. Consider
    # fastapi.concurrency.run_in_threadpool once GPU/thread contention has been
    # evaluated.
    faces = mtcnn(image)

    # A list result (presumably only produced for batched input) is reduced to
    # its first entry. An empty list, or a None entry, means no faces.
    if isinstance(faces, list):
        faces = faces[0] if faces else None

    if faces is None:
        return JSONResponse(content={"error": "No faces detected in the image"})

    # A single face tensor may come back without a leading batch dimension.
    # Add one so the embedding network always receives a batch.
    if faces.dim() == 3:
        faces = faces.unsqueeze(0)

    with torch.no_grad():
        # Tensor.tolist() yields plain Python floats directly; no numpy hop needed.
        embeddings = resnet(faces).cpu().tolist()

    return JSONResponse(content={"embeddings": embeddings})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|