File size: 4,340 Bytes
046b3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
import time

import cv2
import numpy as np

from .center_crop import center_crop
from .face_detector import FaceDetector


class VSNetModelPipeline:
    """Run ``model`` over a whole image and, separately, over detected faces.

    The image is resized and cropped to a multiple of 32, passed through
    ``model`` once as a whole, and -- if ``face_detector`` finds faces -- each
    face crop is also passed through ``model``. The per-face outputs are then
    blended over the full-image output with a soft radial mask.
    """

    def __init__(self, model, face_detector: 'FaceDetector', background_resize=720, no_detected_resize=256):
        """
        Args:
            model: callable applied to a float32 NCHW batch normalized to
                [-1, 1]; its output is treated as an NCHW batch in the same
                range (assumption inferred from the denormalize/clip steps --
                confirm against the model).
            face_detector: callable ``img -> (face_crops, coords)`` where each
                coord is ``(x1, y1, x2, y2)``; must expose ``target_size``,
                which sizes the precomputed blending mask.
            background_resize: longer-side size used for the full image.
            no_detected_resize: longer-side size used when no face is found.
        """
        self.background_resize = background_resize
        self.no_detected_resize = no_detected_resize
        self.model = model
        self.face_detector = face_detector
        # Precomputed once; resized per face in merge_crops.
        self.mask = self.create_circular_mask(face_detector.target_size, face_detector.target_size)

    @staticmethod
    def create_circular_mask(h, w, power=None, clipping_coef=0.85):
        """Build a radial blending mask of shape ``(h, w, 3)`` with values in [0, 1].

        Distances from the centre are clamped below at ``clipping_coef`` times
        the smaller half-extent and above at the larger half-extent, then
        min-max normalized and inverted. The centre disc is therefore 1 and
        the weight falls linearly to 0 at the outer clamp (corners stay 0).

        Args:
            h, w: mask height and width in pixels.
            power: optional exponent applied to the profile to sharpen or
                soften the falloff.
            clipping_coef: fraction of the smaller half-extent that keeps
                full weight.

        Returns:
            float64 array of shape ``(h, w, 3)`` (one profile repeated over
            three channels).
        """
        cx, cy = int(w / 2), int(h / 2)

        ys, xs = np.ogrid[:h, :w]
        dist = np.sqrt((xs - cx) ** 2 + (ys - cy) ** 2)

        # Distance from the centre to the far edge along each axis.
        half_h, half_w = h - cy, w - cx
        clipping_radius = min(half_h, half_w) * clipping_coef
        max_size = max(half_h, half_w)

        # Flatten the centre (full weight) and the outer ring (zero weight).
        dist[dist < clipping_radius] = clipping_radius
        dist[dist > max_size] = max_size

        lo, hi = np.min(dist), np.max(dist)
        if hi > lo:
            weights = 1 - (dist - lo) / (hi - lo)
        else:
            # Every pixel was clamped to the same distance (tiny mask or
            # clipping_coef >= 1), so there is no gradient; use full weight
            # instead of dividing 0 by 0 and returning NaNs.
            weights = np.ones_like(dist)

        if power is not None:
            weights = np.power(weights, power)
        return np.stack([weights] * 3, axis=2)

    @staticmethod
    def resize_size(image, size=720, always_apply=True):
        """Resize ``image`` so that its longer side equals ``size``.

        The aspect ratio is preserved (the short side is truncated to an int,
        never below 1 px). When ``always_apply`` is false the image is only
        shrunk if its shorter side exceeds ``size``; otherwise it goes through
        ``cv2.resize`` at its current size.

        Works for both HxW and HxWxC arrays.
        """
        h, w = np.shape(image)[:2]
        if always_apply or min(h, w) > size:
            if h < w:
                h, w = max(1, int(size * h / w)), size
            else:
                h, w = size, max(1, int(size * w / h))
        return cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)

    def normalize(self, img):
        """Map pixel values from [0, 255] to float32 [-1, 1]."""
        return img.astype(np.float32) / 255 * 2 - 1

    def denormalize(self, img):
        """Map values from [-1, 1] back to [0, 1]."""
        return (img + 1) / 2

    def divide_crop(self, img, must_divided=32):
        """Center-crop ``img`` so that height and width are multiples of ``must_divided``.

        NOTE(review): a side shorter than ``must_divided`` becomes 0; what
        ``center_crop`` does then is not visible here -- confirm.
        """
        h, w = img.shape[:2]
        h = h // must_divided * must_divided
        w = w // must_divided * must_divided
        return center_crop(img, h, w)

    def merge_crops(self, faces_imgs, crops, full_image):
        """Blend each processed face into ``full_image`` in place and return it.

        Each face and the radial mask are resized (bilinear) to the crop box,
        and the box is replaced by ``face * mask + original * (1 - mask)``,
        cast (truncating) to uint8. Assumes every ``(x1, y1, x2, y2)`` box
        lies inside ``full_image`` -- confirm against the face detector.
        """
        for face, (x1, y1, x2, y2) in zip(faces_imgs, crops):
            box_w, box_h = x2 - x1, y2 - y1
            result_face = cv2.resize(face, (box_w, box_h), interpolation=cv2.INTER_LINEAR)
            face_mask = cv2.resize(self.mask, (box_w, box_h), interpolation=cv2.INTER_LINEAR)
            input_face = full_image[y1:y2, x1:x2]
            blended = result_face * face_mask + input_face * (1 - face_mask)
            full_image[y1:y2, x1:x2] = blended.astype(np.uint8)
        return full_image

    def __call__(self, img):
        return self.process_image(img)

    def _infer(self, batch):
        """Run ``self.model`` on a batch of HWC images with 0..255 values.

        Returns ``(output, elapsed_seconds)``, where ``output`` is a uint8
        NHWC array clipped to [0, 255]. ``batch`` may be an array or a
        sequence of equally-shaped images.
        """
        start = time.perf_counter()
        net_input = self.normalize(np.asarray(batch)).transpose(0, 3, 1, 2)
        net_output = self.denormalize(self.model(net_input)).transpose(0, 2, 3, 1)
        result = np.clip(net_output * 255, 0, 255).astype(np.uint8)
        return result, time.perf_counter() - start

    @staticmethod
    def _log_fps(label, elapsed):
        """Log model throughput; a zero timer delta is reported as ``inf``."""
        fps = 1 / elapsed if elapsed > 0 else float('inf')
        logging.info('%s FPS %s', label, fps)

    def process_image(self, img):
        """Run the full pipeline on one image and return the uint8 HWC result.

        1. Resize to ``background_resize`` and crop to a multiple of 32.
        2. Detect faces; run the model on the face crops if any were found,
           otherwise shrink the image to ``no_detected_resize``.
        3. Run the model on the whole image.
        4. Blend the face outputs back over the whole-image output.

        Assumes the model preserves spatial size, so the detector's
        coordinates line up with its output -- TODO confirm.
        """
        img = self.divide_crop(self.resize_size(img, size=self.background_resize))

        face_crops, coords = self.face_detector(img)

        if len(face_crops) > 0:
            out_faces, elapsed = self._infer(face_crops)
            self._log_fps('Face', elapsed)
        else:
            out_faces = []
            # Nothing to paste back, so the background is processed at
            # no_detected_resize instead (presumably to save compute).
            img = self.divide_crop(self.resize_size(img, size=self.no_detected_resize))

        full_image, elapsed = self._infer(np.expand_dims(img, 0))
        self._log_fps('Background', elapsed)

        return self.merge_crops(out_faces, coords, full_image[0])