File size: 3,995 Bytes
319d3b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import cv2
import imghdr
import shutil
import warnings
import numpy as np
import gradio as gr
from dataclasses import dataclass
from mivolo.predictor import Predictor
from utils import is_url, download_file, get_jpg_files, MODEL_DIR

TMP_DIR = "./__pycache__"


@dataclass()
class Cfg:
    """Configuration object handed to ``mivolo.predictor.Predictor``.

    Callers build it positionally as ``Cfg(detector_weights, checkpoint)``,
    so the order of the two required fields must not change.
    """

    # Filesystem path to the person/face detector weights.
    detector_weights: str
    # Filesystem path to the age/gender model checkpoint.
    checkpoint: str
    # Device identifier the predictor should run on; CPU unless overridden.
    device: str = "cpu"
    # NOTE(review): presumably initial person/face usage flags for the model;
    # ValidImgDetector._detect overrides the model's copies on every call.
    with_persons: bool = True
    disable_faces: bool = False
    # NOTE(review): presumably makes Predictor.recognize return a rendered
    # output image (the caller slices that image) — confirm in mivolo.
    draw: bool = True


class ValidImgDetector:
    """Detect people/faces in an image and report child / female / male presence.

    Wraps one MiVOLO ``Predictor`` built from the weights under ``MODEL_DIR``
    and reduces its per-object results to three booleans.
    """

    # Class-level placeholder; __init__ replaces it with a per-instance Predictor.
    predictor = None

    # UI mode label -> (use_persons, disable_faces) for the age/gender model.
    _MODE_FLAGS = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(self):
        """Load the detector weights and age/gender checkpoint from ``MODEL_DIR``."""
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: "Predictor",
    ) -> tuple:
        """Configure *predictor* for one run, detect on *image*, and summarise.

        Args:
            image: Image array passed straight to ``predictor.recognize``.
                ``valid_img`` supplies the BGR array returned by ``cv2.imread``.
            score_threshold: Stored as the detector's ``conf`` keyword.
            iou_threshold: Stored as the detector's ``iou`` keyword.
            mode: One of ``"Use persons and faces"``, ``"Use persons only"``,
                ``"Use faces only"``.
            predictor: The Predictor to run (normally ``self.predictor``).

        Returns:
            ``(image, has_child, has_female, has_male)``. *image* is the
            predictor's output image with its channel axis reversed.
            *has_child* is True when any scored age is below 18.

        Raises:
            ValueError: If *mode* is not one of the supported labels.
        """
        # Validate the mode before mutating the predictor, so a bad call
        # leaves its settings untouched.
        try:
            use_persons, disable_faces = self._MODE_FLAGS[mode]
        except KeyError:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {list(self._MODE_FLAGS)}"
            ) from None

        # These settings live on the predictor object itself, so they are
        # overwritten on every call.
        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        # NOTE(review): objects the age/gender model did not score appear to
        # carry None (confirm against mivolo's result type). Drop them so the
        # comparison never sees None.
        ages = [age for age in detected_objects.ages if age is not None]
        has_child = any(age < 18 for age in ages)
        has_female = "female" in detected_objects.genders
        has_male = "male" in detected_objects.genders

        # The input came from cv2.imread (BGR). Reversing the channel axis
        # presumably yields RGB for Gradio's numpy Image output.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(self, img_path):
        """Run detection on the image file at *img_path* with default settings.

        Uses score threshold 0.4, IoU threshold 0.7 and mode
        ``"Use persons and faces"``. Returns the same 4-tuple as ``_detect``.

        Raises:
            ValueError: If OpenCV cannot read or decode the file.
        """
        image = cv2.imread(img_path)  # BGR array, or None when unreadable
        if image is None:
            raise ValueError(f"Could not read image: {img_path}")
        return self._detect(image, 0.4, 0.7, "Use persons and faces", self.predictor)


# Detector shared across requests, created on first use so the model weights
# are loaded once rather than on every call.
_detector = None


def _get_detector():
    """Return the shared ValidImgDetector, constructing it on first use."""
    global _detector
    if _detector is None:
        _detector = ValidImgDetector()
    return _detector


def infer(photo: str):
    """Gradio handler: run detection on a local image path or an image URL.

    Args:
        photo: A local file path or an http(s) URL. May be None or empty when
            the user submitted nothing.

    Returns:
        The 4-tuple from ``ValidImgDetector.valid_img`` (result image,
        has_child, has_female, has_male). If the input is missing, does not
        exist, or is not a recognised image, returns
        ``(None, None, None, message)`` instead.
    """
    error = (None, None, None, "请正确输入图片 Please input image correctly")

    # Reject empty input before it reaches is_url or the filesystem.
    if not photo:
        return error

    if is_url(photo):
        # NOTE(review): TMP_DIR is "./__pycache__", so this deletes the Python
        # bytecode cache directory before every download.
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)
        photo = download_file(photo, f"{TMP_DIR}/download.jpg")

    # imghdr.what returns None when the header is not a recognised image type.
    # NOTE(review): imghdr is deprecated (3.11) and removed in Python 3.13.
    if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
        return error

    return _get_detector().valid_img(photo)


def _result_outputs():
    """Build a fresh set of output components shared by both tabs' layout."""
    return [
        gr.Image(label="检测结果 Detection Result", type="numpy"),
        gr.Textbox(label="存在儿童 Has Child"),
        gr.Textbox(label="存在女性 Has Female"),
        gr.Textbox(label="存在男性 Has Male"),
    ]


def main():
    """Build the two-tab Gradio UI (file upload and image URL) and launch it."""
    with gr.Blocks() as demo:
        warnings.filterwarnings("ignore")

        # Tab 1: the user uploads a photo; Gradio hands infer a file path.
        with gr.Tab("上传模式 Upload Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label="上传照片 Upload Photo", type="filepath"),
                outputs=_result_outputs(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                allow_flagging="never",
            )

        # Tab 2: the user pastes an image URL; infer downloads it first.
        with gr.Tab("在线模式 Online Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(label="网络图片链接 Online Picture URL"),
                outputs=_result_outputs(),
                allow_flagging="never",
                cache_examples=False,
            )

    demo.launch()


if __name__ == "__main__":
    main()