File size: 1,454 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from .hourglass import FANAU


class Model:
    """Thin wrapper around a ``FANAU`` network restored from a checkpoint.

    Args:
        npts: Number of output points, forwarded to ``FANAU`` as ``n_points``.
        corenet: Path to a checkpoint file whose ``'state_dict'`` entry holds
            the network weights.
        use_cuda: If True, move the network to the default CUDA device.
    """

    def __init__(self, npts=12, corenet='pretrained_models/disfa_adaptation_f0.pth', use_cuda=True):
        self.FAN = FANAU(num_modules=1, n_points=npts)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Tensors are mapped to CPU first so loading does not
        # depend on the device the checkpoint was saved from.
        checkpoint = torch.load(corenet, map_location='cpu')
        self.FAN.load_state_dict(checkpoint['state_dict'])
        self.FAN.eval()
        if use_cuda:
            self.FAN.cuda()

    @staticmethod
    def _last_output(H):
        """Return *H* if it is a tensor, otherwise its last element."""
        # Presumably FANAU may return one output per stacked module (a list);
        # the last one is taken as the prediction - TODO confirm with FANAU.
        return H if isinstance(H, torch.Tensor) else H[-1]

    def __call__(self, x):
        """Run the network on *x* and return a single output tensor.

        Gradients are tracked (no ``torch.no_grad`` here). If the network
        returns a sequence of outputs, only the last one is returned.
        """
        return self._last_output(self.FAN(x))

    def _forward_FAN(self, images):
        """Run the network in eval mode without gradient tracking.

        Returns the network's raw output (tensor or sequence) unchanged.
        """
        with torch.no_grad():
            self.FAN.eval()
            H = self.FAN(images)
        return H

    def forward_FAN(self, images):
        """Run the network with gradient tracking; return its raw output unchanged."""
        H = self.FAN(images)
        return H


class AUdetector:
    """Per-channel intensity detector built on ``Model``.

    Args:
        au_model_path: Checkpoint path forwarded to ``Model`` as ``corenet``.
        use_cuda: If True, the network is moved to CUDA and every input image
            is moved to the default CUDA device before inference.
    """

    def __init__(self, au_model_path='models/pretrained/au_detector/disfa_adaptation_f0.pth', use_cuda=True):
        # Number of network outputs, forwarded to Model as ``npts``.
        self.naus = 12
        self.AUdetector = Model(npts=self.naus, corenet=au_model_path, use_cuda=use_cuda)
        self.use_cuda = use_cuda

    def detect_AU(self, img):
        """Return one intensity per network output channel for *img*.

        Args:
            img: Image tensor, either a single image (3-D, a batch axis is
                added) or a batch (4-D). It is min-max normalized to [0, 1]
                using the min/max of the whole tensor. For a batch, the
                normalization is over the entire batch, not per image.
                A constant image normalizes to all zeros.

        Returns:
            Tensor of shape ``(batch, channels)`` holding the spatial maximum
            of each output heatmap. ``torch.no_grad`` is not used, so
            gradients are tracked through the network.
        """
        if not img.is_floating_point():
            # Integer images: cast up front. This gives the same float result
            # the division alone would produce, and lets the range guard
            # below work on a float tensor.
            img = img.float()
        low = img.min()
        span = img.max() - low
        # A constant image has zero range. Divide by 1 instead of 0 so the
        # result is zeros rather than NaN.
        span = torch.where(span > 0, span, torch.ones_like(span))
        img_normalized = (img - low) / span
        if self.use_cuda:
            img_normalized = img_normalized.cuda()

        if img_normalized.ndim == 3:
            img_normalized = img_normalized.unsqueeze(0)

        heatmaps = self.AUdetector.forward_FAN(img_normalized)
        # The network's raw output may be a sequence of outputs (Model.__call__
        # handles the same case); use the last one.
        if not isinstance(heatmaps, torch.Tensor):
            heatmaps = heatmaps[-1]
        # Global spatial max per channel. On 64x64 maps this equals the former
        # fixed MaxPool2d((64, 64)), but it also works for other map sizes.
        intensities = torch.nn.functional.adaptive_max_pool2d(heatmaps, 1).flatten(1)

        return intensities