dattarij's picture
adding ContraCLIP folder
8c212a5
raw
history blame
1.45 kB
import torch
from .hourglass import FANAU
class Model:
    """Thin wrapper around a ``FANAU`` network restored from a checkpoint.

    The wrapped network is stored in ``self.FAN``. Three entry points run it:
    ``__call__`` (keeps only the last element of a non-tensor output),
    ``_forward_FAN`` (eval mode, gradients disabled) and ``forward_FAN``
    (plain forward pass, autograd untouched).
    """

    def __init__(self, npts=12, corenet='pretrained_models/disfa_adaptation_f0.pth', use_cuda=True):
        """Build the network, load its weights and put it in eval mode.

        Args:
            npts: Forwarded to ``FANAU`` as ``n_points``.
            corenet: Path to a checkpoint file whose ``'state_dict'`` entry
                holds the network weights.
            use_cuda: If true, move the network to the default CUDA device.
        """
        self.FAN = FANAU(num_modules=1, n_points=npts)
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be passed here.
        checkpoint = torch.load(corenet, map_location='cpu')
        self.FAN.load_state_dict(checkpoint['state_dict'])
        self.FAN.eval()
        if use_cuda:
            self.FAN.cuda()

    def __call__(self, x):
        """Run the network on ``x`` and return a single output.

        If the network returns a tensor it is returned unchanged; otherwise
        the output is indexed with ``[-1]`` (presumably a sequence of
        per-stage outputs -- confirm against ``FANAU.forward``).
        """
        H = self.FAN(x)
        # isinstance also accepts torch.Tensor subclasses, which a comparison
        # against the class name 'Tensor' would wrongly treat as a sequence.
        return H if isinstance(H, torch.Tensor) else H[-1]

    def _forward_FAN(self, images):
        """Run the network in eval mode with gradient tracking disabled."""
        with torch.no_grad():
            self.FAN.eval()
            H = self.FAN(images)
        return H

    def forward_FAN(self, images):
        """Run the network on ``images`` without altering autograd or mode."""
        H = self.FAN(images)
        return H
class AUdetector:
    """Produces per-channel peak heatmap responses for an input image.

    Wraps a ``Model`` built with ``npts=12`` (``self.naus``) and reduces its
    heatmap output to one value per channel with a max-pool.
    """

    def __init__(self, au_model_path='models/pretrained/au_detector/disfa_adaptation_f0.pth', use_cuda=True):
        """Load the underlying model.

        Args:
            au_model_path: Checkpoint path forwarded to ``Model`` as ``corenet``.
            use_cuda: If true, the model and every input passed to
                ``detect_AU`` are moved to the default CUDA device.
        """
        self.naus = 12
        self.AUdetector = Model(npts=self.naus, corenet=au_model_path, use_cuda=use_cuda)
        self.use_cuda = use_cuda

    def detect_AU(self, img):
        """Min-max normalise ``img``, run the model and max-pool the heatmaps.

        Args:
            img: Image tensor. A 3-D input gets a leading batch dimension
                added; other ranks are passed through unchanged.

        Returns:
            The heatmaps max-pooled with a 64x64 window, with dimension 2
            squeezed twice. Assumes the heatmaps are 64x64 spatially, so the
            result is (batch, channels) -- TODO confirm against ``FANAU``.
        """
        lowest = img.min()
        span = img.max() - lowest
        # A constant image has zero range; dividing by it would fill the
        # input with NaN. Substituting 1 maps such an image to all zeros and
        # leaves every non-constant image's result unchanged.
        span = torch.where(span > 0, span, torch.ones_like(span))
        img_normalized = (img - lowest) / span
        if self.use_cuda:
            img_normalized = img_normalized.cuda()
        if img_normalized.ndim == 3:
            img_normalized = img_normalized.unsqueeze(0)
        heatmaps = self.AUdetector.forward_FAN(img_normalized)
        intensities = torch.nn.MaxPool2d((64, 64))(heatmaps).squeeze(2).squeeze(2)
        return intensities