"""
This module provides a Processor class for running images through controlnet_aux detectors
"""
import io
import logging
from typing import Dict, Optional, Union

from PIL import Image

from controlnet_aux import (CannyDetector, ContentShuffleDetector,
                            DWposeDetector, HEDdetector, LeresDetector,
                            LineartAnimeDetector, LineartDetector,
                            MediapipeFaceDetector, MidasDetector, MLSDdetector,
                            NormalBaeDetector, OpenposeDetector, PidiNetDetector,
                            ZoeDetector)

LOGGER = logging.getLogger(__name__)


MODELS = {
    # checkpoint models
    'scribble_hed': {'class': HEDdetector, 'checkpoint': True},
    'softedge_hed': {'class': HEDdetector, 'checkpoint': True},
    'scribble_hedsafe': {'class': HEDdetector, 'checkpoint': True},
    'softedge_hedsafe': {'class': HEDdetector, 'checkpoint': True},
    'depth_midas': {'class': MidasDetector, 'checkpoint': True},
    'mlsd': {'class': MLSDdetector, 'checkpoint': True},
    'openpose': {'class': OpenposeDetector, 'checkpoint': True},
    'openpose_face': {'class': OpenposeDetector, 'checkpoint': True},
    'openpose_faceonly': {'class': OpenposeDetector, 'checkpoint': True},
    'openpose_full': {'class': OpenposeDetector, 'checkpoint': True},
    'openpose_hand': {'class': OpenposeDetector, 'checkpoint': True},
    'dwpose': {'class': DWposeDetector, 'checkpoint': True},
    'scribble_pidinet': {'class': PidiNetDetector, 'checkpoint': True},
    'softedge_pidinet': {'class': PidiNetDetector, 'checkpoint': True},
    'scribble_pidsafe': {'class': PidiNetDetector, 'checkpoint': True},
    'softedge_pidsafe': {'class': PidiNetDetector, 'checkpoint': True},
    'normal_bae': {'class': NormalBaeDetector, 'checkpoint': True},
    'lineart_coarse': {'class': LineartDetector, 'checkpoint': True},
    'lineart_realistic': {'class': LineartDetector, 'checkpoint': True},
    'lineart_anime': {'class': LineartAnimeDetector, 'checkpoint': True},
    'depth_zoe': {'class': ZoeDetector, 'checkpoint': True},
    'depth_leres': {'class': LeresDetector, 'checkpoint': True},
    'depth_leres++': {'class': LeresDetector, 'checkpoint': True},
    # instantiate
    'shuffle': {'class': ContentShuffleDetector, 'checkpoint': False},
    'mediapipe_face': {'class': MediapipeFaceDetector, 'checkpoint': False},
    'canny': {'class': CannyDetector, 'checkpoint': False},
}


MODEL_PARAMS = {
    'scribble_hed': {'scribble': True},
    'softedge_hed': {'scribble': False},
    'scribble_hedsafe': {'scribble': True, 'safe': True},
    'softedge_hedsafe': {'scribble': False, 'safe': True},
    'depth_midas': {},
    'mlsd': {},
    'openpose': {'include_body': True, 'include_hand': False, 'include_face': False},
    'openpose_face': {'include_body': True, 'include_hand': False, 'include_face': True},
    'openpose_faceonly': {'include_body': False, 'include_hand': False, 'include_face': True},
    'openpose_full': {'include_body': True, 'include_hand': True, 'include_face': True},
    'openpose_hand': {'include_body': False, 'include_hand': True, 'include_face': False},
    'dwpose': {},
    'scribble_pidinet': {'safe': False, 'scribble': True},
    'softedge_pidinet': {'safe': False, 'scribble': False},
    'scribble_pidsafe': {'safe': True, 'scribble': True},
    'softedge_pidsafe': {'safe': True, 'scribble': False},
    'normal_bae': {},
    'lineart_realistic': {'coarse': False},
    'lineart_coarse': {'coarse': True},
    'lineart_anime': {},
    'canny': {},
    'shuffle': {},
    'depth_zoe': {},
    'depth_leres': {'boost': False},
    'depth_leres++': {'boost': True},
    'mediapipe_face': {},
}

CHOICES = f"Choices for the processor are {list(MODELS.keys())}"


class Processor:
    def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None:
        """Processor that can be used to process images with controlnet aux processors

        Args:
            processor_id (str): processor name; must be one of the keys of MODELS,
                                e.g. 'canny', 'depth_midas', 'openpose_full',
                                'lineart_anime', 'scribble_hed', 'mediapipe_face'
            params (Optional[Dict]): parameters for the processor
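
        Example (illustrative sketch; checkpoint-based processors load pretrained
        weights, downloading them on first use):
            >>> canny = Processor('canny')
            >>> edges = canny(pil_image)  # pil_image: any PIL.Image.Image you already have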
        """
        LOGGER.info(f"Loading {processor_id}")

        if processor_id not in MODELS:
            raise ValueError(f"{processor_id} is not a valid processor id. Please make sure to choose one of {', '.join(MODELS.keys())}")

        self.processor_id = processor_id
        self.processor = self.load_processor(self.processor_id)

        # copy the default params so per-instance overrides don't mutate MODEL_PARAMS
        self.params = dict(MODEL_PARAMS[self.processor_id])
        # update with user params
        if params:
            self.params.update(params)

    def load_processor(self, processor_id: str):
        """Load a controlnet aux detector for the given processor id

        Args:
            processor_id (str): processor name

        Returns:
            the instantiated controlnet aux detector
        """
        processor = MODELS[processor_id]['class']

        # check if the processor is a checkpoint model
        if MODELS[processor_id]['checkpoint']:
            processor = processor.from_pretrained("lllyasviel/Annotators")
        else:
            processor = processor()
        return processor

    def __call__(self, image: Union[Image.Image, bytes],
                 to_pil: bool = True) -> Union[Image.Image, bytes]:
        """processes an image with a controlnet aux processor

        Args:
            image (Union[Image.Image, bytes]): input image in bytes or PIL Image
            to_pil (bool): return a PIL Image if True, JPEG-encoded bytes if False

        Returns:
            Union[Image.Image, bytes]: processed image in bytes or PIL Image
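
        Example (illustrative sketch; `raw` stands in for image bytes you already hold):
            >>> jpeg_bytes = processor(raw, to_pil=False)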
        """
        # check if bytes or PIL Image
        if isinstance(image, bytes):
            image = Image.open(io.BytesIO(image)).convert("RGB")

        processed_image = self.processor(image, **self.params)

        if to_pil:
            return processed_image
        else:
            output_bytes = io.BytesIO()
            processed_image.save(output_bytes, format='JPEG')
            return output_bytes.getvalue()
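

# Minimal usage sketch (illustrative, not part of the library API): run this module
# directly to process a local image. Both command-line arguments are placeholders
# for whatever processor id and image path you actually have.
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)

    processor_id, image_path = sys.argv[1], sys.argv[2]  # e.g. "canny" and "input.png" (hypothetical)

    processor = Processor(processor_id)
    with open(image_path, "rb") as image_file:
        processed = processor(image_file.read(), to_pil=True)
    processed.save("processed.png")  # hypothetical output path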