File size: 2,673 Bytes
e064690 b36b72b e064690 b36b72b d36811d b36b72b d36811d b36b72b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import io
import os
from typing import List

import torch
from facenet_pytorch import MTCNN, InceptionResnetV1
from fastapi import FastAPI, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.param_functions import File
from fastapi.responses import JSONResponse
from PIL import Image
app = FastAPI()

# Open CORS policy: every origin, HTTP method and request header is accepted.
# NOTE(review): allow_credentials=True together with a wildcard origin makes
# Starlette's CORSMiddleware echo back the caller's origin on credentialed
# requests. If no cookie/auth-based browser clients exist, consider
# allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Redirect torch's cache to a writable path under /tmp. This is presumably
# consulted when the pretrained weights are fetched below, so it must stay
# ahead of the InceptionResnetV1(pretrained=...) construction.
os.environ['TORCH_HOME'] = '/tmp/.cache/torch'

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Face detector/aligner. keep_all=True asks it to return every detected face
# rather than a single one (per facenet_pytorch docs).
mtcnn = MTCNN(keep_all=True, device=device)

# Embedding network with VGGFace2 weights, in eval mode, placed on `device`.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet.eval()
resnet = resnet.to(device)
@app.get("/", tags=["Home"])
def read_root():
    """Greet callers that hit the service root with a fixed message."""
    welcome = "Welcome to the face embeddings API!"
    return {"message": welcome}
@app.get("/health", tags=["Health"])
def health_check():
    """Return a constant OK payload; no dependency checks are performed."""
    return dict(status="ok")
def _load_rgb_image(contents):
    """Decode raw uploaded bytes into an RGB PIL image.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        OSError: If Pillow cannot identify or fully decode the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
        PIL.Image.DecompressionBombError: If the image exceeds Pillow's
            decompression-bomb pixel limit.
    """
    # Image.open is lazy; convert() forces the full decode so any decoding
    # error is raised here rather than later inside MTCNN.
    return Image.open(io.BytesIO(contents)).convert('RGB')


def _embed_faces(image):
    """Detect faces in *image* and return one embedding vector per face.

    Blocking (MTCNN detection + InceptionResnetV1 inference); call it from a
    worker thread, not directly on the event loop.

    Args:
        image: An RGB ``PIL.Image.Image``.

    Returns:
        A list of embeddings (each a list of floats), one per face passed to
        ``resnet``, or ``None`` if no face was detected.
    """
    # Grad mode is thread-local, so no_grad must be entered in this worker.
    with torch.no_grad():
        faces = mtcnn(image)
        # MTCNN returns a list only for batched (list) input. For a single
        # image with keep_all=True it returns one tensor stacking every
        # detected face (per facenet_pytorch docs), or None if there are none.
        face = faces[0] if isinstance(faces, list) else faces
        if face is None:
            return None
        # Ensure the face tensor is 4D (batch_size, channels, height, width).
        if face.dim() == 3:
            face = face.unsqueeze(0)
        # MTCNN builds its aligned crops on the CPU (the facenet_pytorch README
        # moves them with .to(device) before embedding). Without this move,
        # inference fails with a device mismatch whenever resnet is on CUDA.
        embeddings = resnet(face.to(device))
    return embeddings.cpu().numpy().tolist()


@app.post("/extract", tags=["Extract Embeddings"])
async def extract_embeddings(file: UploadFile = File(...)):
    """Return face embeddings for every face found in an uploaded image.

    Decoding and model inference are CPU/GPU-bound, so they run in FastAPI's
    worker thread pool instead of blocking the event loop (which would also
    stall unrelated requests such as ``/health``).

    Args:
        file: The uploaded image, in any format Pillow can read.

    Returns:
        ``{"embeddings": [[...], ...]}`` with one vector per detected face;
        ``{"error": "No faces detected in the image"}`` (HTTP 200) when no
        face is found; or HTTP 400 with an ``"error"`` body when the upload
        cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = await run_in_threadpool(_load_rgb_image, contents)
    except (OSError, Image.DecompressionBombError):
        # An undecodable upload is a client error, not a server fault.
        return JSONResponse(
            status_code=400,
            content={"error": "Could not decode the uploaded file as an image"},
        )

    embeddings = await run_in_threadpool(_embed_faces, image)
    if embeddings is None:
        # Status 200 is kept so the existing response contract is unchanged.
        return JSONResponse(content={"error": "No faces detected in the image"})
    return JSONResponse(content={"embeddings": embeddings})
# NOTE: Superseded implementation based on the `face_recognition` library,
# kept for reference only. It is inactive (commented out, and
# `face_recognition` is not imported), and if re-enabled it would register a
# second handler for the same "/extract" route.
# @app.post("/extract")
# async def extract_embeddings(file: UploadFile = File(...)):
#     # Load the image
#     contents = await file.read()
#     image = face_recognition.load_image_file(io.BytesIO(contents))
#     # Find all the faces in the image
#     face_locations = face_recognition.face_locations(image)
#     # Initialize an empty list to store the face embeddings
#     embeddings = []
#     # Loop through each face location
#     for face_location in face_locations:
#         # Extract the face encoding
#         face_encoding = face_recognition.face_encodings(image, [face_location])[0]
#         # Append the face encoding to the embeddings list
#         embeddings.append(face_encoding.tolist())
#     return JSONResponse(content={"embeddings": embeddings})
|