import os
from fastapi import FastAPI, UploadFile
from fastapi.responses import JSONResponse
from fastapi.param_functions import File
from fastapi.middleware.cors import CORSMiddleware
from typing import List
import io
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from PIL import Image

app = FastAPI()
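
# Allow cross-origin requests from any origin, with any method and headers,
# so a browser frontend hosted elsewhere can call this API directly.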
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set the cache directory to a writable location
os.environ['TORCH_HOME'] = '/tmp/.cache/torch'
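
# Model setup: MTCNN detects and crops faces (keep_all=True returns every
# detected face), and InceptionResnetV1 pretrained on VGGFace2 maps each
# 160x160 face crop to a 512-dimensional embedding vector.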
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mtcnn = MTCNN(keep_all=True, device=device)
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)

# Basic routes (the paths "/" and "/health" are assumed here).
@app.get("/")
def read_root():
    return {"message": "Welcome to the face embeddings API!"}


@app.get("/health")
def health_check():
    return {"status": "ok"}

@app.post("/extract")
async def extract_embeddings(file: UploadFile = File(...)):
    # Load the uploaded image and convert it to RGB
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert('RGB')

    # Detect and crop faces
    faces = mtcnn(image)

    # Check whether any faces were detected
    if faces is None:
        return JSONResponse(content={"error": "No faces detected in the image"})

    # If faces is a list, take the first face; if it is a tensor, use it as-is
    if isinstance(faces, list):
        face = faces[0]
    else:
        face = faces

    # Ensure the face tensor is 4D (batch_size, channels, height, width)
    if face.dim() == 3:
        face = face.unsqueeze(0)

    # Move the face tensor to the same device as the embedding model
    face = face.to(device)

    # Extract the face embeddings: one 512-dimensional vector per face crop
    with torch.no_grad():
        embeddings = resnet(face).cpu().numpy().tolist()

    return JSONResponse(content={"embeddings": embeddings})
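
# A successful response has the form {"embeddings": [[...], ...]}, with one
# 512-element list per detected face. The commented-out block below is an
# earlier version of the same endpoint based on the face_recognition library.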
# @app.post("/extract")
# async def extract_embeddings(file: UploadFile = File(...)):
#     # Load the image
#     contents = await file.read()
#     image = face_recognition.load_image_file(io.BytesIO(contents))
#     # Find all the faces in the image
#     face_locations = face_recognition.face_locations(image)
#     # Initialize an empty list to store the face embeddings
#     embeddings = []
#     # Loop through each face location
#     for face_location in face_locations:
#         # Extract the face encoding
#         face_encoding = face_recognition.face_encodings(image, [face_location])[0]
#         # Append the face encoding to the embeddings list
#         embeddings.append(face_encoding.tolist())
#     return JSONResponse(content={"embeddings": embeddings})
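
# Example usage (illustrative; the module name "app" and port 7860 are
# assumptions, not taken from the original deployment):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST -F "file=@face.jpg" http://localhost:7860/extract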