Update main.py
Browse files
main.py
CHANGED
@@ -36,19 +36,32 @@ def read_root():
|
|
36 |
def health_check():
|
37 |
return {"status": "ok"}
|
38 |
|
39 |
-
|
40 |
-
|
41 |
@app.post("/extract", tags=["Extract Embeddings"])
|
42 |
async def extract_embeddings(file: UploadFile = File(...)):
|
43 |
# Load the image
|
44 |
contents = await file.read()
|
45 |
image = Image.open(io.BytesIO(contents)).convert('RGB')
|
46 |
|
47 |
-
#
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# Extract the face embeddings
|
51 |
-
|
|
|
52 |
|
53 |
return JSONResponse(content={"embeddings": embeddings})
|
54 |
|
|
|
36 |
def health_check():
|
37 |
return {"status": "ok"}
|
38 |
|
|
|
|
|
39 |
@app.post("/extract", tags=["Extract Embeddings"])
|
40 |
async def extract_embeddings(file: UploadFile = File(...)):
|
41 |
# Load the image
|
42 |
contents = await file.read()
|
43 |
image = Image.open(io.BytesIO(contents)).convert('RGB')
|
44 |
|
45 |
+
# Detect faces
|
46 |
+
faces = mtcnn(image)
|
47 |
+
|
48 |
+
# Check if any faces were detected
|
49 |
+
if faces is None:
|
50 |
+
return JSONResponse(content={"error": "No faces detected in the image"})
|
51 |
+
|
52 |
+
# If faces is a list, take the first face. If it's a tensor, it's already the first (and only) face
|
53 |
+
if isinstance(faces, list):
|
54 |
+
face = faces[0]
|
55 |
+
else:
|
56 |
+
face = faces
|
57 |
+
|
58 |
+
# Ensure the face tensor is 4D (batch_size, channels, height, width)
|
59 |
+
if face.dim() == 3:
|
60 |
+
face = face.unsqueeze(0)
|
61 |
+
|
62 |
# Extract the face embeddings
|
63 |
+
with torch.no_grad():
|
64 |
+
embeddings = resnet(face).cpu().numpy().tolist()
|
65 |
|
66 |
return JSONResponse(content={"embeddings": embeddings})
|
67 |
|