File size: 1,943 Bytes
18f531e
 
 
 
 
60cc4ec
18f531e
 
 
 
 
 
 
60cc4ec
 
18f531e
60cc4ec
18f531e
60cc4ec
 
 
 
 
 
d6866b9
60cc4ec
 
 
d6866b9
 
 
 
60cc4ec
d6866b9
60cc4ec
 
 
 
18f531e
 
60cc4ec
d6866b9
 
60cc4ec
18f531e
 
 
d6866b9
 
 
 
 
 
 
 
 
60cc4ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import os
import torch
from facenet_pytorch import MTCNN, InceptionResnetV1
import logging
from PIL import Image

# Module-level logger named after this module; FacialProcessing reports
# detection failures and errors through it.
logger = logging.getLogger(__name__)

class FacialProcessing:
    """Detect a single face in an image and compute its VGGFace2 embedding.

    Wraps a facenet-pytorch ``MTCNN`` detector and an ``InceptionResnetV1``
    network pretrained on VGGFace2. Both are placed on CUDA when it is
    available and on CPU otherwise.
    """

    def __init__(self):
        # Redirect torch's cache (where the pretrained weights are stored) to
        # /tmp. NOTE(review): presumably because the default cache location is
        # not writable in the deployment environment — confirm.
        os.environ['TORCH_HOME'] = '/tmp/.cache/torch'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # keep_all=True returns every detected face rather than one, which lets
        # extract_embeddings_vgg reject images containing several faces.
        self.mtcnn = MTCNN(keep_all=True, device=self.device)
        self.resnet = InceptionResnetV1(pretrained='vggface2').eval().to(self.device)

    def _load_rgb(self, image_path):
        """Open *image_path*, return an RGB copy, and close the file handle."""
        with Image.open(image_path) as img:
            # convert() returns a new, fully loaded image, so the result stays
            # valid after the source file is closed by the context manager.
            return img.convert('RGB')

    def extract_embeddings_vgg(self, image_path):
        """Return the embedding of the single face found in *image_path*.

        Args:
            image_path: Anything accepted by ``PIL.Image.open`` (path or file).

        Returns:
            The flattened embedding as a list of floats, or ``None`` when no
            face is detected, more than one face is detected, alignment fails,
            or any exception occurs. The reason is logged in every case.
        """
        try:
            img = self._load_rgb(image_path)

            # Detect first so images with zero or several faces are rejected
            # before running the (more expensive) embedding network.
            boxes, _ = self.mtcnn.detect(img)

            if boxes is None or len(boxes) == 0:
                logger.warning("No face detected in image: %s", image_path)
                return None

            if len(boxes) > 1:
                logger.warning("Multiple faces detected in image: %s", image_path)
                return None

            # Crop and align the single detected face. NOTE: calling the
            # detector module runs detection again internally.
            face = self.mtcnn(img, return_prob=False)

            if face is None:
                logger.warning("Failed to align face in image: %s", image_path)
                return None

            # The aligned crop must be on the same device as the network's
            # weights (facenet-pytorch's own usage moves it with .to(device));
            # this is a no-op on CPU and prevents a device-mismatch error on CUDA.
            with torch.no_grad():
                embeddings = self.resnet(face.to(self.device)).cpu().numpy().flatten()
            return embeddings.tolist()

        except Exception as e:
            # Best-effort API: signal failure with None, but keep the traceback.
            logger.exception("An error occurred while extracting embeddings: %s", e)
            return None

    def preprocess_image(self, image_path):
        """Open *image_path* as an RGB ``PIL.Image``.

        Returns:
            The RGB image, or ``None`` if it could not be opened or decoded
            (the error is logged).
        """
        try:
            return self._load_rgb(image_path)
        except Exception as e:
            logger.error("Error opening image: %s", e)
            return None