|
import torch |
|
from .hourglass import FANAU |
|
|
|
|
|
class Model:
    """Thin inference wrapper around a pretrained ``FANAU`` hourglass network.

    A ``FANAU`` network with a single hourglass module is built, the
    ``'state_dict'`` entry of a checkpoint file is loaded into it, and it is
    put in eval mode (and optionally moved to the GPU).
    """

    def __init__(self, npts=12, corenet='pretrained_models/disfa_adaptation_f0.pth', use_cuda=True):
        """Build the network and load its pretrained weights.

        Args:
            npts: Passed to ``FANAU`` as ``n_points`` (number of predicted
                points/channels; presumably one per AU — confirm against
                ``hourglass.FANAU``).
            corenet: Path to a checkpoint file containing a ``'state_dict'`` key.
            use_cuda: If True, move the network to the default CUDA device.
        """
        self.FAN = FANAU(num_modules=1, n_points=npts)
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(corenet, map_location='cpu')
        self.FAN.load_state_dict(checkpoint['state_dict'])
        self.FAN.eval()
        if use_cuda:
            self.FAN.cuda()

    def __call__(self, x):
        """Run the network on ``x`` and return a single output tensor.

        If the network returns a tensor, it is returned unchanged. Otherwise
        the output is treated as a sequence and its last element is returned.
        Gradients are tracked; use ``_forward_FAN`` for a no-grad pass.
        """
        H = self.FAN(x)
        # isinstance (instead of comparing the class *name* to 'Tensor') also
        # accepts Tensor subclasses, which would otherwise be indexed with
        # [-1] and silently lose all but the last batch element.
        return H if isinstance(H, torch.Tensor) else H[-1]

    def _forward_FAN(self, images):
        """Run the network in eval mode without gradient tracking.

        Returns the raw network output (not unwrapped like ``__call__``).
        """
        with torch.no_grad():
            # Re-assert eval mode in case a caller switched the module to train().
            self.FAN.eval()
            H = self.FAN(images)
        return H

    def forward_FAN(self, images):
        """Run the network with gradient tracking and return its raw output.

        Unlike ``_forward_FAN``, this neither disables autograd nor forces
        eval mode, so gradients can flow back to ``images``.
        """
        H = self.FAN(images)
        return H
|
|
|
|
|
class AUdetector:
    """Action-unit intensity detector built on a pretrained ``Model``.

    Each image is min-max normalized and passed through the network. For
    every output heatmap channel, the spatial maximum is reported as that
    channel's intensity.
    """

    def __init__(self, au_model_path='models/pretrained/au_detector/disfa_adaptation_f0.pth', use_cuda=True):
        """Load the pretrained AU network.

        Args:
            au_model_path: Path to the checkpoint passed to ``Model``.
            use_cuda: If True, the network and the inputs to ``detect_AU``
                are moved to the default CUDA device.
        """
        # Number of output channels requested from the network (passed as npts).
        self.naus = 12
        self.AUdetector = Model(npts=self.naus, corenet=au_model_path, use_cuda=use_cuda)
        self.use_cuda = use_cuda

    def detect_AU(self, img):
        """Estimate per-channel AU intensities for an image or a batch.

        Args:
            img: Image tensor with 3 dims (single image, a batch dim is added)
                or 4 dims (batch). Presumably (C, H, W) / (B, C, H, W) —
                confirm against the network's expected input. Min-max
                normalization uses the min/max of the *whole* tensor, not
                per image.

        Returns:
            Tensor of shape (B, n_channels): the maximum value of each
            heatmap channel for each image. Gradients are tracked.
        """
        lo = img.min()
        value_range = img.max() - lo
        # A constant image has zero range; divide by 1 instead so the result
        # is all zeros rather than NaN. Non-degenerate inputs are unaffected.
        value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
        img_normalized = (img - lo) / value_range

        if self.use_cuda:
            img_normalized = img_normalized.cuda()

        if img_normalized.ndim == 3:
            img_normalized = img_normalized.unsqueeze(0)

        heatmaps = self.AUdetector.forward_FAN(img_normalized)

        # Global spatial max per channel: (B, C, H, W) -> (B, C). Identical to
        # MaxPool2d((64, 64)) + squeeze for 64x64 heatmaps, but correct for
        # any heatmap size instead of silently keeping extra spatial dims.
        intensities = heatmaps.flatten(start_dim=2).max(dim=2).values

        return intensities
|
|