"""Extract 68-point facial landmarks from every video in a folder.

Each video is decoded frame by frame. Landmarks are detected per frame, and
the flattened result is written to
``<output_dir>/<video parent dir>/<video stem>.txt``.
"""
import os
import cv2
import time
import glob
import argparse
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from itertools import cycle
from torch.multiprocessing import Pool, Process, set_start_method

from facexlib.alignment import landmark_98_to_68
from facexlib.detection import init_detection_model
from facexlib.utils import load_file_from_url
from facexlib.alignment.awing_arch import FAN

# Landmarks per frame after the 98 -> 68 conversion (matches the -1 fallback shape).
NUM_LANDMARKS = 68

VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})


def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
    """Build a facexlib face-alignment network and load its pretrained weights.

    Args:
        model_name: Alignment model name; only ``'awing_fan'`` is supported.
        half: Accepted for signature compatibility; currently unused.
        device: Torch device to load the model onto.
        model_rootpath: Passed to ``load_file_from_url`` as ``save_dir``
            (where the weight file is stored/downloaded).

    Returns:
        The model in eval mode, moved to ``device``.

    Raises:
        NotImplementedError: If ``model_name`` is not supported.
    """
    if model_name == 'awing_fan':
        model = FAN(num_modules=4, num_landmarks=98, device=device)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
    else:
        raise NotImplementedError(f'{model_name} is not implemented.')

    model_path = load_file_from_url(
        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
    model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
    model.eval()
    model = model.to(device)
    return model


class KeypointExtractor():
    """Detect a face in each image and return its 68 landmark coordinates.

    Uses a RetinaFace detector to locate the face box and an AWing FAN model
    on the cropped face. The 98 predicted points are converted to the 68-point
    layout and shifted back into full-image coordinates.
    """

    def __init__(self, device='cuda', root_path='gfpgan/weights'):
        """Load the alignment and detection models.

        Args:
            device: Torch device for both models.
            root_path: Directory where model weights are stored. The default
                is the path the previous implementation always ended up using.
        """
        # The former `try: import webui` probe was dead code: its result was
        # unconditionally overwritten with 'gfpgan/weights'. The location is
        # now an explicit, overridable parameter with that same default.
        self.detector = init_alignment_model('awing_fan', device=device, model_rootpath=root_path)
        self.det_net = init_detection_model('retinaface_resnet50', half=False, device=device, model_rootpath=root_path)

    def extract_keypoint(self, images, name=None, info=True):
        """Extract landmarks from one image or a list of images.

        Args:
            images: A single image (anything ``np.array`` accepts, e.g. a PIL
                image) or a list of such images.
            name: Optional path. When given, the landmarks are saved flattened
                to ``<name without extension>.txt``.
            info: Show a tqdm progress bar when ``images`` is a list.

        Returns:
            For a list: an array of shape ``(len(images), 68, 2)``. A frame
            with no usable face reuses the previous frame's landmarks; leading
            faceless frames are filled with -1.
            For a single image: a ``(68, 2)`` array, all -1 when no usable
            face is found.
        """
        if isinstance(images, list):
            keypoints = self._extract_sequence(images, info)
        else:
            keypoints = self._extract_single(images)
        # Both branches now honour name=None (the list branch used to crash).
        if name is not None:
            np.savetxt(os.path.splitext(name)[0] + '.txt', keypoints.reshape(-1))
        return keypoints

    def _extract_sequence(self, images, info):
        """Run `_extract_single` over a frame list, carrying landmarks over faceless frames."""
        frames = tqdm(images, desc='landmark Det:') if info else images
        collected = []
        for image in frames:
            current_kp = self._extract_single(image)
            if np.all(current_kp == -1) and collected:
                # No face in this frame: repeat the previous frame's landmarks.
                collected.append(collected[-1])
            else:
                collected.append(current_kp[None])
        if not collected:
            # np.concatenate([]) raises; an empty frame list yields an empty result.
            return np.empty((0, NUM_LANDMARKS, 2))
        return np.concatenate(collected, 0)

    def _extract_single(self, image):
        """Return ``(68, 2)`` landmarks for one image, or the all -1 sentinel."""
        while True:
            try:
                with torch.no_grad():
                    # face detection -> face alignment.
                    img = np.array(image)
                    # The detector gets the original object (not `img`), as before.
                    bboxes = self.det_net.detect_faces(image, 0.97)
                    if bboxes is None or len(bboxes) == 0:
                        # An empty result used to raise IndexError on bboxes[0],
                        # bypassing the "no face" fallback below.
                        print('No face detected in this image')
                        return self._missing_keypoints()
                    bbox = bboxes[0]
                    # Clamp the crop origin: a negative start index would wrap
                    # around in the slice instead of starting at the border.
                    x1 = max(int(bbox[0]), 0)
                    y1 = max(int(bbox[1]), 0)
                    crop = img[y1:int(bbox[3]), x1:int(bbox[2]), :]
                    keypoints = landmark_98_to_68(self.detector.get_landmarks(crop))
                    # Shift landmarks from crop coordinates back to the full image.
                    keypoints[:, 0] += x1
                    keypoints[:, 1] += y1
                    return keypoints
            except RuntimeError as e:
                if str(e).startswith('CUDA'):
                    # Treated as transient (e.g. OOM): wait and retry the same image.
                    print("Warning: out of memory, sleep for 1s")
                    time.sleep(1)
                else:
                    print(e)
                    # Previously fell through with `keypoints` unbound.
                    return self._missing_keypoints()
            except TypeError:
                print('No face detected in this image')
                return self._missing_keypoints()

    @staticmethod
    def _missing_keypoints():
        """Sentinel landmarks (all -1) for an image without a usable face."""
        return -1. * np.ones([NUM_LANDMARKS, 2])


def read_video(filename):
    """Decode every frame of ``filename`` into a list of RGB PIL images."""
    frames = []
    cap = cv2.VideoCapture(filename)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert to RGB before wrapping in PIL.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        cap.release()
    return frames


def run(data):
    """Pool task: extract and save landmarks for one video.

    Args:
        data: ``(filename, opt, device)``. ``opt`` is the parsed CLI namespace
            and ``device`` is the GPU id string used for CUDA_VISIBLE_DEVICES.

    The result is written to
    ``<opt.output_dir>/<parent dir of filename>/<video stem>.txt``.
    """
    filename, opt, device = data
    # NOTE(review): CUDA_VISIBLE_DEVICES only matters before CUDA is initialised
    # in this process, and Pool workers are reused across tasks, so later tasks
    # in a worker likely stay on its first device. Confirm whether per-worker
    # device pinning (e.g. a Pool initializer) is wanted.
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)
    # os.path instead of split('/') so Windows paths returned by glob work too.
    parent = os.path.basename(os.path.dirname(filename))
    out_dir = os.path.join(opt.output_dir, parent)
    os.makedirs(out_dir, exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(out_dir, os.path.basename(filename))
    )


def main():
    """Parse CLI arguments and extract landmarks for every video in --input_dir."""
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, required=True, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, required=True, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)
    opt = parser.parse_args()

    if not os.path.isdir(opt.input_dir):
        parser.error(f'input directory not found: {opt.input_dir}')

    # Accumulate across extensions (the old loop kept only the last glob).
    # A set de-duplicates names matched by both cases on case-insensitive
    # filesystems.
    matches = set()
    for ext in VIDEO_EXTENSIONS:
        print(f'{opt.input_dir}/*.{ext}')
        matches.update(glob.glob(os.path.join(glob.escape(opt.input_dir), f'*.{ext}')))
    filenames = sorted(matches)
    print('Total number of videos:', len(filenames))

    args_list = cycle([opt])
    device_ids = cycle(opt.device_ids.split(","))
    with Pool(opt.workers) as pool:
        tasks = zip(filenames, args_list, device_ids)
        for _ in tqdm(pool.imap_unordered(run, tasks), total=len(filenames)):
            pass


if __name__ == '__main__':
    main()