Testys commited on
Commit
d36811d
1 Parent(s): e064690

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +19 -6
main.py CHANGED
@@ -36,19 +36,32 @@ def read_root():
36
  def health_check():
37
  return {"status": "ok"}
38
 
39
-
40
-
41
  @app.post("/extract", tags=["Extract Embeddings"])
42
  async def extract_embeddings(file: UploadFile = File(...)):
43
  # Load the image
44
  contents = await file.read()
45
  image = Image.open(io.BytesIO(contents)).convert('RGB')
46
 
47
- # Preprocess the image
48
- preprocessed_image = mtcnn(image)
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Extract the face embeddings
51
- embeddings = resnet(preprocessed_image.unsqueeze(0)).detach().cpu().numpy().tolist()
 
52
 
53
  return JSONResponse(content={"embeddings": embeddings})
54
 
 
36
  def health_check():
37
  return {"status": "ok"}
38
 
 
 
39
  @app.post("/extract", tags=["Extract Embeddings"])
40
  async def extract_embeddings(file: UploadFile = File(...)):
41
  # Load the image
42
  contents = await file.read()
43
  image = Image.open(io.BytesIO(contents)).convert('RGB')
44
 
45
+ # Detect faces
46
+ faces = mtcnn(image)
47
+
48
+ # Check if any faces were detected
49
+ if faces is None:
50
+ return JSONResponse(content={"error": "No faces detected in the image"})
51
+
52
+ # If faces is a list, take the first face. If it's a tensor, it's already the first (and only) face
53
+ if isinstance(faces, list):
54
+ face = faces[0]
55
+ else:
56
+ face = faces
57
+
58
+ # Ensure the face tensor is 4D (batch_size, channels, height, width)
59
+ if face.dim() == 3:
60
+ face = face.unsqueeze(0)
61
+
62
  # Extract the face embeddings
63
+ with torch.no_grad():
64
+ embeddings = resnet(face).cpu().numpy().tolist()
65
 
66
  return JSONResponse(content={"embeddings": embeddings})
67