File size: 9,584 Bytes
3d4d894
 
 
 
 
 
32a644b
be0162b
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32a644b
 
 
3d4d894
be0162b
 
 
 
cd7e597
be0162b
 
 
 
 
cd7e597
be0162b
 
 
 
 
 
 
 
 
82d3087
 
 
 
be0162b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32a644b
04b1201
32a644b
 
41e92f0
 
 
 
 
 
 
 
be0162b
41e92f0
 
 
 
 
 
 
 
 
 
 
 
 
be0162b
41e92f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a54498b
3d4d894
 
 
 
 
be0162b
 
3d4d894
 
a54498b
3d4d894
 
 
 
 
 
 
 
 
 
 
a54498b
3d4d894
 
 
 
 
41e92f0
be0162b
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41e92f0
 
3d4d894
 
b12d9cc
41e92f0
3d4d894
 
 
b12d9cc
41e92f0
 
 
 
b12d9cc
 
41e92f0
 
 
 
 
 
 
 
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41e92f0
c12df56
3d4d894
41e92f0
b12d9cc
c12df56
41e92f0
 
 
 
 
 
 
 
c12df56
 
41e92f0
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41e92f0
 
3d4d894
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
"""This file contains methods for inference and image generation."""
import gc
import logging
import threading
import time
from contextlib import contextmanager
from time import perf_counter
from typing import List, Tuple, Dict

import numpy as np
import streamlit as st
import torch
from PIL import Image
from PIL import ImageFilter
from scipy.signal import fftconvolve
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from diffusers import ControlNetModel, UniPCMultistepScheduler
from diffusers import StableDiffusionInpaintPipeline

from config import WIDTH, HEIGHT
from palette import ade_palette
from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline

# Logger named after this module, so records can be filtered per module.
LOGGING = logging.getLogger(name=__name__)

def flush():
    """Free unreachable Python objects, then drop PyTorch's cached CUDA blocks.

    A full ``gc`` pass runs first so that tensors kept alive only by reference
    cycles are released before the CUDA caching allocator is asked to return
    its unused memory.
    """
    gc.collect()
    torch.cuda.empty_cache()

class ControlNetPipeline:
    """First-come-first-served wrapper around the ControlNet inpaint/img2img pipeline.

    Loads the segmentation ControlNet and the inpainting pipeline in float16 and
    moves them to CUDA. Each call takes a ticket, polls every 0.5 s until that
    ticket is at the head of ``waiting_queue``, runs the pipeline, then always
    releases the ticket and frees cached memory.
    """

    def __init__(self):
        # NOTE(review): not read anywhere in this class.
        self.in_use = False
        self.controlnet = ControlNetModel.from_pretrained(
            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)

        self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16
        )

        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")

        # Tickets of callers that are waiting or running, in arrival order.
        self.waiting_queue = []
        # Last ticket number handed out.
        self.count = 0
        # Serialises ticket assignment: ``self.count += 1`` is a read-modify-write,
        # so concurrent callers could otherwise receive the same number.
        self._ticket_lock = threading.Lock()

    @property
    def queue_size(self):
        """Number of calls currently waiting for, or holding, the pipeline."""
        return len(self.waiting_queue)

    def __call__(self, **kwargs):
        """Run the wrapped pipeline with ``kwargs`` once all earlier callers are done.

        Returns:
            Whatever the underlying pipeline returns.
        """
        with self._ticket_lock:
            self.count += 1
            number = self.count
            self.waiting_queue.append(number)

        try:
            # wait until our ticket reaches the head of the queue
            while self.waiting_queue[0] != number:
                print(f"Wait for your turn {number} in queue {self.waiting_queue}")
                time.sleep(0.5)

            print("It's the turn of", number)
            return self.pipe(**kwargs)
        finally:
            # Release the ticket even if the pipeline raised; otherwise every
            # later caller would spin forever behind a ticket nobody removes.
            # remove() (not pop(0)) so we can never drop someone else's ticket.
            self.waiting_queue.remove(number)
            flush()
    
class SDPipeline:
    """First-come-first-served wrapper around the Stable Diffusion 2 inpainting pipeline.

    Loads ``stabilityai/stable-diffusion-2-inpainting`` in float16 and moves it
    to CUDA. Each call takes a ticket, polls every 0.5 s until that ticket is at
    the head of ``waiting_queue``, runs the pipeline, then always releases the
    ticket and frees cached memory.
    """

    def __init__(self):
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
            safety_checker=None,
        )

        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")

        # Tickets of callers that are waiting or running, in arrival order.
        self.waiting_queue = []
        # Last ticket number handed out.
        self.count = 0
        # Serialises ticket assignment: ``self.count += 1`` is a read-modify-write,
        # so concurrent callers could otherwise receive the same number.
        self._ticket_lock = threading.Lock()

    @property
    def queue_size(self):
        """Number of calls currently waiting for, or holding, the pipeline."""
        return len(self.waiting_queue)

    def __call__(self, **kwargs):
        """Run the wrapped pipeline with ``kwargs`` once all earlier callers are done.

        Returns:
            Whatever the underlying pipeline returns.
        """
        with self._ticket_lock:
            self.count += 1
            number = self.count
            self.waiting_queue.append(number)

        try:
            # wait until our ticket reaches the head of the queue
            while self.waiting_queue[0] != number:
                print(f"Wait for your turn {number} in queue {self.waiting_queue}")
                time.sleep(0.5)

            print("It's the turn of", number)
            return self.pipe(**kwargs)
        finally:
            # Release the ticket even if the pipeline raised; otherwise every
            # later caller would spin forever behind a ticket nobody removes.
            # remove() (not pop(0)) so we can never drop someone else's ticket.
            self.waiting_queue.remove(number)
            flush()


def convolution(mask: Image.Image, size=9) -> Image.Image:
    """Soften a mask with a ``size`` x ``size`` box (mean) blur.

    ``fftconvolve(..., 'same')`` treats pixels outside the image as zero, which
    would darken the edges. To avoid that, the outer ``size``-pixel frame is
    restored to its original, unblurred values.

    Args:
        mask (Image.Image): mask image; converted to single-channel "L" first.
        size (int, optional): side length of the averaging kernel and width of
            the preserved border. Defaults to 9.
    Returns:
        Image.Image: blurred single-channel ("L") mask.
    Raises:
        ValueError: if ``size`` is smaller than 1.
    """
    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")

    mask = np.array(mask.convert("L"))
    kernel = np.ones((size, size)) / size**2
    blurred = fftconvolve(mask, kernel, 'same')
    # FFT convolution leaves float round-off (e.g. 254.9999 or -1e-13). Round
    # and clip before the uint8 cast so saturated areas stay 255 instead of
    # truncating to 254, and tiny negatives cannot wrap around.
    mask_blended = np.clip(np.rint(blurred), 0, 255).astype(np.uint8)

    border = size

    # replace borders with original values
    mask_blended[:border, :] = mask[:border, :]
    mask_blended[-border:, :] = mask[-border:, :]
    mask_blended[:, :border] = mask[:, :border]
    mask_blended[:, -border:] = mask[:, -border:]

    return Image.fromarray(mask_blended).convert("L")


def postprocess_image_masking(inpainted: Image, image: Image, mask: Image) -> Image:
    """Blend the generated image back onto the original using ``mask`` as alpha.

    Where the mask is white the inpainted pixels are kept; where it is black the
    original pixels are kept; grey values mix the two.

    Args:
        inpainted (Image): image produced by the diffusion pipeline.
        image (Image): original input image.
        mask (Image): blending mask, passed straight to ``Image.composite``.
    Returns:
        Image: composited RGB image.
    """
    foreground = inpainted.convert("RGBA")
    background = image.convert("RGBA")
    blended = Image.composite(foreground, background, mask)
    return blended.convert("RGB")


@st.experimental_singleton(max_entries=5)
def get_controlnet() -> ControlNetPipeline:
    """Load the queued ControlNet inpainting pipeline.

    Cached with ``st.experimental_singleton``, so the models are loaded once and
    the same wrapper (including its waiting queue) is reused by later calls.

    Returns:
        ControlNetPipeline: queue-serialised wrapper around the ControlNet
        inpaint/img2img diffusion pipeline.
    """
    return ControlNetPipeline()


@st.experimental_singleton(max_entries=5)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Load the UperNet (ConvNeXt-small) semantic-segmentation components.

    Cached with ``st.experimental_singleton`` so repeated calls reuse the
    already-loaded processor and model.

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: the image
        pre/post-processor and the segmentation model, in that order.
    """
    checkpoint = "openmmlab/upernet-convnext-small"
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    segmentor = UperNetForSemanticSegmentation.from_pretrained(checkpoint)
    return processor, segmentor


@st.experimental_singleton(max_entries=5)
def get_inpainting_pipeline() -> SDPipeline:
    """Load the queued Stable Diffusion 2 inpainting pipeline.

    Cached with ``st.experimental_singleton``, so the model is loaded once and
    the same wrapper (including its waiting queue) is reused by later calls.

    Returns:
        SDPipeline: queue-serialised wrapper around
        ``StableDiffusionInpaintPipeline``.
    """
    return SDPipeline()


@torch.inference_mode()
def make_image_controlnet(image: np.ndarray,
                          mask_image: np.ndarray,
                          controlnet_conditioning_image: np.ndarray,
                          positive_prompt: str, negative_prompt: str,
                          seed: int = 2356132) -> Image.Image:
    """Generate a new image with the segmentation-conditioned ControlNet pipeline.

    The pipeline repaints the masked region. The result is then composited onto
    the original through a blurred copy of the mask to soften the seam.

    Args:
        image (np.ndarray): input image, anything accepted by ``Image.fromarray``.
        mask_image (np.ndarray): mask array. It is multiplied by 255 and cast to
            uint8, so it presumably holds values in [0, 1] — TODO confirm with callers.
        controlnet_conditioning_image (np.ndarray): conditioning (segmentation) image.
        positive_prompt (str): positive prompt string.
        negative_prompt (str): negative prompt string.
        seed (int, optional): seed for the CUDA generator. Defaults to 2356132.
    Returns:
        Image.Image: the generated image blended over the original.
    """
    pipe = get_controlnet()
    flush()

    image = Image.fromarray(image).convert("RGB")
    controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB")
    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
    # Blurred mask is used only for the final composite; the pipeline gets the hard mask.
    mask_image_postproc = convolution(mask_image)

    # Read the queue length once so both numbers in the message agree.
    queue_size = pipe.queue_size
    st.success(f"{queue_size} images in the queue, can take up to {(queue_size+1) * 10} seconds")
    generated_image = pipe(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=20,
        strength=1.00,
        guidance_scale=7.0,
        generator=[torch.Generator(device="cuda").manual_seed(seed)],
        image=image,
        mask_image=mask_image,
        controlnet_conditioning_image=controlnet_conditioning_image,
    ).images[0]
    return postprocess_image_masking(generated_image, image, mask_image_postproc)


@torch.inference_mode()
def make_inpainting(positive_prompt: str,
                    image: Image.Image,
                    mask_image: np.ndarray,
                    negative_prompt: str = "") -> Image.Image:
    """Inpaint the masked region of ``image`` with the Stable Diffusion 2 inpainting pipeline.

    The result is composited onto the original through a blurred copy of the
    mask to soften the seam.

    Args:
        positive_prompt (str): positive prompt string.
        image (Image.Image): input image.
        mask_image (np.ndarray): mask array. It is multiplied by 255 and cast to
            uint8, so it presumably holds values in [0, 1] — TODO confirm with callers.
        negative_prompt (str, optional): negative prompt string. Defaults to "".
    Returns:
        Image.Image: the inpainted image blended over the original.
    """
    pipe = get_inpainting_pipeline()
    # convolution() and the pipeline both need a PIL image, not the raw array.
    mask_pil = Image.fromarray((mask_image * 255).astype(np.uint8))
    mask_image_postproc = convolution(mask_pil)

    flush()
    # Read the queue length once so both numbers in the message agree.
    queue_size = pipe.queue_size
    st.success(f"{queue_size} images in the queue, can take up to {(queue_size+1) * 10} seconds")
    # NOTE(review): the original also splatted an undefined ``common_parameters``
    # dict here (NameError on every call); pipeline defaults are used instead.
    generated_image = pipe(image=image,
                           mask_image=mask_pil,
                           prompt=positive_prompt,
                           negative_prompt=negative_prompt,
                           num_inference_steps=20,
                           height=HEIGHT,
                           width=WIDTH,
                           ).images[0]
    return postprocess_image_masking(generated_image, image, mask_image_postproc)


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Segment ``image`` with UperNet and render the label map as an RGB image.

    Each pixel is painted with the ``ade_palette()`` colour of its predicted
    label. Labels without a palette entry stay black.

    Args:
        image (Image.Image): input image.
    Returns:
        Image.Image: RGB segmentation image with the same size as ``image``.
    """
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # Gradients are already disabled by @torch.inference_mode().
    outputs = image_segmentor(pixel_values)

    # target_sizes expects (height, width); PIL's .size is (width, height).
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    labels = seg.cpu().numpy()

    # Colour every pixel with one lookup-table gather instead of scanning the
    # whole image once per palette entry. The table is padded with black rows
    # so any label beyond the palette stays black, as before.
    palette = np.array(ade_palette()).astype(np.uint8)
    n_rows = len(palette)
    if labels.size:
        n_rows = max(n_rows, int(labels.max()) + 1)
    lut = np.zeros((n_rows, 3), dtype=np.uint8)
    lut[:len(palette)] = palette
    color_seg = lut[labels]

    return Image.fromarray(color_seg).convert('RGB')