File size: 12,225 Bytes
3d4d894
 
 
 
 
 
 
 
 
dd0ab9f
 
f8c7d9d
3d4d894
 
 
 
 
 
 
 
 
 
 
 
71af695
 
 
 
3790166
 
71af695
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71af695
 
 
 
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef697d2
3d4d894
7ea369b
3d4d894
 
 
 
9bfe550
 
 
3d4d894
 
 
9bfe550
 
 
 
 
 
3d4d894
 
 
 
 
 
 
 
9bfe550
 
 
 
 
 
3d4d894
9bfe550
 
 
 
 
 
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71af695
 
 
 
3d4d894
71af695
 
 
 
 
 
 
 
 
 
 
 
 
 
c882d5b
71af695
 
 
 
 
 
3d4d894
 
 
2ea3bd8
0cae6b6
 
6e3a1b8
0cae6b6
5a67d9b
 
2379311
5a67d9b
 
 
 
e97c302
 
 
 
 
 
 
5a67d9b
13e5061
2379311
8604dfb
0d84e52
9bfe550
5a67d9b
3d4d894
 
 
f6b9c19
 
 
 
 
 
 
3d4d894
 
 
 
d0613b1
3d4d894
 
 
cbc9e9a
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c882d5b
3d4d894
 
b999def
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import random
import numpy as np
import os
import time

from models import make_image_controlnet, make_inpainting
from segmentation import segment_image
from config import HEIGHT, WIDTH, POS_PROMPT, NEG_PROMPT, COLOR_MAPPING, map_colors, map_colors_rgb
from palette import COLOR_MAPPING_CATEGORY
from preprocessing import preprocess_seg_mask, get_image, get_mask

# Configure the Streamlit page to use the "wide" layout. This runs at import
# time, before any other Streamlit element of the app is rendered.
st.set_page_config(layout="wide")


def on_upload() -> None:
    """Store the uploaded file as the working image and drop stale results.

    Used as the ``on_change`` callback of the file uploader. When a file is
    present under the ``'input_image'`` key it is opened, converted to RGB and
    saved as ``'initial_image'``. The cached segmentation, its colour list and
    any previous output image are then removed from the session state.
    """
    no_upload = ('input_image' not in st.session_state
                 or st.session_state['input_image'] is None)
    if no_upload:
        return

    uploaded_file = st.session_state['input_image']
    st.session_state['initial_image'] = Image.open(uploaded_file).convert('RGB')

    # Everything derived from the previous image is no longer valid.
    for stale_key in ('seg', 'unique_colors', 'output_image'):
        if stale_key in st.session_state:
            del st.session_state[stale_key]


def check_reset_state() -> bool:
    """Report whether a canvas reset was requested and clear the request.

    The ``'reset_canvas'`` flag in the session state is always set to False
    before returning, so a pending request is reported only once.

    Returns:
        bool: True if the UI elements need to be reset, False otherwise
    """
    state = st.session_state
    reset_requested = 'reset_canvas' in state and bool(state['reset_canvas'])
    state['reset_canvas'] = False
    return reset_requested


def move_image(source: Union[str, Image.Image],
               dest: str,
               rerun: bool = True,
               remove_state: bool = True) -> None:
    """Move image from source to destination.

    Args:
        source (Union[str, Image.Image]): source image, either a PIL image or
            the session-state key under which the image is stored
        dest (str): session-state key that receives the image
        rerun (bool, optional): rerun streamlit. Defaults to True.
        remove_state (bool, optional): remove the canvas state. Defaults to True.

    Raises:
        KeyError: if ``source`` is a key that is not present in the session state.
    """
    source_image = source if isinstance(source, Image.Image) else st.session_state[source]

    if remove_state:
        # Request a canvas reset and drop the segmentation data that belonged
        # to the previous image.
        st.session_state['reset_canvas'] = True
        for stale_key in ('seg', 'unique_colors'):
            if stale_key in st.session_state:
                del st.session_state[stale_key]

    st.session_state[dest] = source_image
    if rerun:
        # ``st.experimental_rerun`` is deprecated in favour of ``st.rerun``;
        # prefer the new name when this Streamlit version provides it and fall
        # back to the old one otherwise.
        rerun_app = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun_app()


def on_change_radio() -> None:
    """Request a canvas reset.

    Registered as the ``on_change`` callback of the generation-mode selectbox
    in the sidebar; it only raises the ``'reset_canvas'`` flag in the session
    state.
    """
    state = st.session_state
    state['reset_canvas'] = True


def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Assemble the keyword arguments for the drawable canvas.

    Args:
        canvas_color: colour used for both the fill and the stroke
        brush: stroke width
        paint_mode: drawing mode of the canvas
        _reset_state: when truthy, an empty initial drawing is passed along

    Returns:
        dict: keyword arguments to be unpacked into ``st_canvas``
    """
    # Use the current working image as background when one has been stored.
    background = None
    if 'initial_image' in st.session_state:
        background = st.session_state['initial_image']

    # On reset hand over a drawing without objects; otherwise pass None.
    initial_drawing = None
    if _reset_state:
        initial_drawing = {'version': '4.4.0', 'objects': []}

    return {
        'fill_color': canvas_color,
        'stroke_color': canvas_color,
        'background_color': "#FFFFFF",
        'background_image': background,
        'stroke_width': brush,
        'initial_drawing': initial_drawing,
        'update_streamlit': True,
        'height': 512,
        'width': 512,
        'drawing_mode': paint_mode,
        'key': "canvas",
    }

def make_prompt_row():
    """Render the positive and negative prompt inputs in two columns.

    The entered texts are stored in the session state under the keys
    ``'positive_prompt'`` and ``'negative_prompt'``.
    """
    positive_column, negative_column = st.columns(2)
    positive_column.text_input(
        label="Positive prompt",
        value="a photograph of a room, interior design, 4k, high resolution",
        key='positive_prompt',
    )
    negative_column.text_input(
        label="Negative prompt",
        value="lowres, watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly",
        key='negative_prompt',
    )

def make_sidebar():
    """Render the sidebar controls and collect the chosen settings.

    Returns:
        tuple: ``(input_image, generation_mode, brush, color_chooser, paint_mode)``
        where ``input_image`` is the value of the file uploader and the other
        entries depend on the selected generation mode.
    """
    with st.sidebar:
        input_image = st.file_uploader("", type=["png", "jpg"], key='input_image', on_change=on_upload)
        generation_mode = st.selectbox(
            "Generation mode",
            ["Re-generate objects", "Segmentation conditioning", "Inpainting"],
            on_change=on_change_radio,
        )

        if generation_mode == "Re-generate objects":
            # Nothing is painted in this mode: transparent colour, zero-width brush.
            color_chooser = "rgba(0, 0, 0, 0.0)"
            paint_mode = 'freedraw'
            brush = 0
        else:
            # The painting controls are identical for the remaining modes.
            paint_mode = st.sidebar.selectbox("Painting mode", ("freedraw", "polygon"))
            brush = (st.slider("Stroke width", 5, 140, 100, key='slider_seg')
                     if paint_mode == "freedraw" else 5)

            if generation_mode == "Segmentation conditioning":
                category_chooser = st.sidebar.selectbox(
                    "Filter on category",
                    list(COLOR_MAPPING_CATEGORY.keys()),
                    index=0,
                    key='category_chooser',
                )
                chosen_colors = list(COLOR_MAPPING_CATEGORY[category_chooser].keys())
                color_chooser = st.sidebar.selectbox(
                    "Choose a color",
                    chosen_colors,
                    index=0,
                    format_func=map_colors,
                    key='color_chooser',
                )
            else:
                color_chooser = "#000000"

    return input_image, generation_mode, brush, color_chooser, paint_mode


def make_output_image():
    """Show the generated image (or a white placeholder) with a move button.

    The image stored under ``'output_image'`` in the session state is converted
    to a PIL image when it is a numpy array and resized to 512x512 when it is a
    PIL image. If no output exists yet, a white 512x512 placeholder is shown.
    The "Move to input image" button copies the output to ``'initial_image'``
    via ``move_image``.
    """
    has_output = 'output_image' in st.session_state
    if has_output:
        output_image = st.session_state['output_image']
        if isinstance(output_image, np.ndarray):
            output_image = Image.fromarray(output_image)

        if isinstance(output_image, Image.Image):
            output_image = output_image.resize((512, 512))
    else:
        output_image = Image.new('RGB', (512, 512), (255, 255, 255))

    st.write("#### Output image")
    st.image(output_image, width=512)
    if st.button("Move to input image"):
        if has_output:
            move_image('output_image', 'initial_image', remove_state=True, rerun=True)
        else:
            # Without this guard move_image would raise KeyError, because the
            # 'output_image' key does not exist before the first generation.
            st.warning("There is no generated image to move yet. Generate an image first.")

def _store_output_image(result_image) -> None:
    """Save a generated image under 'output_image', converting arrays to PIL."""
    if isinstance(result_image, np.ndarray):
        result_image = Image.fromarray(result_image)
    st.session_state['output_image'] = result_image


def _generate_with_controlnet(image, mask, conditioning_image) -> None:
    """Call the ControlNet model with the current prompts and store the result."""
    result_image = make_image_controlnet(
        image=image,
        mask_image=mask,
        controlnet_conditioning_image=conditioning_image,
        positive_prompt=st.session_state['positive_prompt'],
        negative_prompt=st.session_state['negative_prompt'],
        seed=random.randint(0, 100000),  # nosec
    )
    _store_output_image(result_image)


def _run_segmentation_conditioning(canvas) -> None:
    """Segment the input image, merge it with the drawing and generate an image."""
    image = get_image()
    print("Preparing image segmentation")
    real_seg = segment_image(Image.fromarray(image))
    mask, seg = preprocess_seg_mask(canvas, real_seg)

    with st.spinner(text="Generating image"):
        print("Making image")
        _generate_with_controlnet(image, mask, seg)


def _ensure_segmentation_state() -> None:
    """Compute and cache the segmentation and its unique colours when missing."""
    if 'seg' not in st.session_state:
        with st.spinner(text="Preparing image segmentation"):
            image = get_image()
            st.session_state['seg'] = np.array(segment_image(Image.fromarray(image)))

    if 'unique_colors' not in st.session_state:
        real_seg = st.session_state['seg']
        # Flatten to a list of pixels so np.unique yields one row per colour.
        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
        st.session_state['unique_colors'] = [tuple(color) for color in unique_colors]


def _run_regenerate_objects(chosen_colors) -> None:
    """Mask every pixel whose segmentation colour was chosen and generate an image."""
    image = get_image()
    print(chosen_colors)

    segmentation = st.session_state['seg']
    mask = np.zeros_like(segmentation)
    for color in chosen_colors:
        # Pixels matching the colour on every channel are marked with 1.
        mask[(segmentation == color).all(axis=2)] = 1

    with st.spinner(text="Generating image"):
        _generate_with_controlnet(image, mask, segmentation)


def _run_inpainting(canvas, image) -> None:
    """Build the mask from the canvas drawing and run the inpainting model."""
    canvas_mask = canvas.image_data
    if canvas_mask is None:
        # NOTE(review): presumably the canvas reports no image data until the
        # component has sent a drawing back; np.array(None) would not be a
        # usable mask, so skip the generation instead.
        st.warning("The canvas has no drawing data yet. Draw a mask and try again.")
        return
    if not isinstance(canvas_mask, np.ndarray):
        canvas_mask = np.array(canvas_mask)
    mask = get_mask(canvas_mask)

    with st.spinner(text="Generating new images"):
        print("Making image")
        result_image = make_inpainting(
            positive_prompt=st.session_state['positive_prompt'],
            image=Image.fromarray(image),
            mask_image=mask,
            negative_prompt=st.session_state['negative_prompt'],
        )
        _store_output_image(result_image)


def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the input canvas and the generate button for the selected mode.

    Args:
        canvas_color: fill and stroke colour handed to the canvas
        brush: stroke width handed to the canvas
        _reset_state: whether the canvas drawing should be reset
        generation_mode: one of "Segmentation conditioning",
            "Re-generate objects" or "Inpainting"; any other value only
            renders the heading
        paint_mode: drawing mode handed to the canvas

    When the generate button is pressed, the resulting image is stored in the
    session state under ``'output_image'``.
    """
    st.write("#### Input image")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state
    )

    if generation_mode == "Segmentation conditioning":
        canvas = st_canvas(**canvas_dict)
        if st.button("generate image", key='generate_button'):
            _run_segmentation_conditioning(canvas)

    elif generation_mode == "Re-generate objects":
        # The canvas only displays the image here; its drawing is not used.
        st_canvas(**canvas_dict)
        _ensure_segmentation_state()

        with st.expander("Explanation", expanded=True):
            st.write("This mode allows you to choose which objects you want to re-generate in the image. "
                     "Use the selection dropdown to add or remove objects. If you are ready, press the generate button"
                     " to generate the image, which can take up to 30 seconds. If you want to improve the generated image, click"
                     " the 'Move to input image' button."
                     )

        chosen_colors = st.multiselect(
            label="Choose which concepts you want to regenerate in the image",
            options=st.session_state['unique_colors'],
            key='chosen_colors',
            default=st.session_state['unique_colors'],
            format_func=map_colors_rgb,
        )

        if st.button("generate image", key='generate_button'):
            _run_regenerate_objects(chosen_colors)

    elif generation_mode == "Inpainting":
        image = get_image()
        canvas = st_canvas(**canvas_dict)
        if st.button("generate images", key='generate_button'):
            _run_inpainting(canvas, image)

def main():
    """Build the page: title, sidebar and, once an image is uploaded, the editor.

    Without an uploaded image only a hint to upload one is shown. Otherwise the
    prompt row is rendered, followed by the editing canvas in the left column
    and the output image in the right column.
    """
    st.write("## Controlnet sprint - interior design", unsafe_allow_html=True)

    # The uploader value itself is not needed here; it is read from the session state.
    _, generation_mode, brush, color_chooser, paint_mode = make_sidebar()

    image_uploaded = ('input_image' in st.session_state
                      and st.session_state['input_image'] is not None)
    if not image_uploaded:
        print("Image not present")
        st.success("Upload an image to start")
        return

    make_prompt_row()

    _reset_state = check_reset_state()

    editing_column, output_column = st.columns(2)
    with editing_column:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=_reset_state,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )

    with output_column:
        make_output_image()
            

# Build the app only when this file is run as the main module, not on import.
if __name__ == "__main__":
    main()