michaelapplydesign commited on
Commit
13cb3ce
·
1 Parent(s): 54588b3
Files changed (14) hide show
  1. __init__.py +0 -0
  2. app.py +56 -4
  3. colors.py +343 -0
  4. config.py +35 -0
  5. empty_room.jpg +0 -0
  6. explanation.py +51 -0
  7. helpers.py +47 -0
  8. models.py +98 -0
  9. palette.py +38 -0
  10. pipelines.py +126 -0
  11. preprocessing.py +134 -0
  12. requirements.txt +10 -10
  13. segmentation.py +55 -0
  14. stable_diffusion_controlnet_inpaint_img2img.py +1112 -0
__init__.py ADDED
File without changes
app.py CHANGED
@@ -1,7 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "V3 Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch(share=True)
 
1
+ # import gradio as gr
2
+ #
3
+ # def greet(name):
4
+ # return "V5 Hello " + name + "!!"
5
+ #
6
+ # iface = gr.Interface(
7
+ # fn=greet,
8
+ # inputs="text",
9
+ # outputs="text",
10
+ # title="MB TEST 1",
11
+ # )
12
+ # iface.launch(share=True)
13
+
14
  import gradio as gr
15
+ from models import make_inpainting
16
+ import io
17
+ from PIL import Image
18
+ import numpy as np
19
+
20
+ # from transformers import pipeline
21
+ #
22
+ # pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
23
+
24
+ def image_to_byte_array(image: Image) -> bytes:
25
+ # BytesIO is a fake file stored in memory
26
+ imgByteArr = io.BytesIO()
27
+ # image.save expects a file as a argument, passing a bytes io ins
28
+ image.save(imgByteArr, format='png') # image.format
29
+ # Turn the BytesIO object back into a bytes object
30
+ imgByteArr = imgByteArr.getvalue()
31
+ return imgByteArr
32
+
33
+ def predict(input_img1,input_img2):
34
+
35
+ # image = Image.open(requests.get("https://applydesignblobs-chh5aahjdzh0cnew.z01.azurefd.net/spaceimages/org_sqr_7fee0869-3187-4363-b5fb-5233e943649d.png", stream=True).raw)
36
+ # mask = Image.open(requests.get("https://applydesign.blob.core.windows.net/spaceimages/mask_e85b1585-8.png", stream=True).raw)
37
+
38
+ result_image = make_inpainting(positive_prompt='test1',
39
+ image=image_to_byte_array(input_img1),
40
+ mask_image=np.array(input_img2),
41
+ negative_prompt="xxx",
42
+ )
43
+
44
+
45
+ # predictions = pipeline(input_img1)
46
+ return input_img1
47
+
48
+ gradio_app = gr.Interface(
49
+ predict,
50
+ inputs=[gr.Image(label="img", sources=['upload', 'webcam'], type="pil"),
51
+ gr.Image(label="mask", sources=['upload', 'webcam'], type="pil")
52
+ ],
53
+ outputs= gr.Image(label="resp"),
54
+ title="rem fur 1",
55
+ )
56
+
57
 
58
+ gradio_app.launch(share=True)
 
59
 
 
 
colors.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Color mappings"""
2
+ from typing import List, Dict
3
+
4
+ TRIVIA = {
5
+ "#B47878": "building;edifice",
6
+ "#06E6E6": "sky",
7
+ "#04C803": "tree",
8
+ "#8C8C8C": "road;route",
9
+ "#04FA07": "grass",
10
+ "#96053D": "person;individual;someone;somebody;mortal;soul",
11
+ "#CCFF04": "plant;flora;plant;life",
12
+ "#787846": "earth;ground",
13
+ "#FF09E0": "house",
14
+ "#0066C8": "car;auto;automobile;machine;motorcar",
15
+ "#3DE6FA": "water",
16
+ "#FF3D06": "railing;rail",
17
+ "#FF5C00": "arcade;machine",
18
+ "#FFE000": "stairs;steps",
19
+ "#00F5FF": "fan",
20
+ "#FF008F": "step;stair",
21
+ "#1F00FF": "stairway;staircase",
22
+ "#FFD600": "radiator",
23
+ }
24
+
25
+ OBJECTS = {
26
+ "#CC05FF": "bed",
27
+ "#FF0633": "painting;picture",
28
+ "#DCDCDC": "mirror",
29
+ "#00FF14": "box",
30
+ "#FF0000": "flower",
31
+ "#FFA300": "book",
32
+ "#00FFC2": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
33
+ "#F500FF": "pot;flowerpot",
34
+ "#00FFCC": "vase",
35
+ "#29FF00": "tray",
36
+ "#8FFF00": "poster;posting;placard;notice;bill;card",
37
+ "#5CFF00": "basket;handbasket",
38
+ "#00ADFF": "screen;door;screen",
39
+ }
40
+
41
+
42
+ SITTING = {
43
+ "#0B66FF": "sofa;couch;lounge",
44
+ "#CC4603": "chair",
45
+ "#07FFE0": "seat",
46
+ "#08FFD6": "armchair",
47
+ "#FFC207": "cushion",
48
+ "#00EBFF": "pillow",
49
+ "#00D6FF": "stool",
50
+ "#1400FF": "blanket;cover",
51
+ "#0A00FF": "swivel;chair",
52
+ "#FF9900": "ottoman;pouf;pouffe;puff;hassock",
53
+ }
54
+
55
+ LIGHTING = {
56
+ "#E0FF08": "lamp",
57
+ "#FFAD00": "light;light;source",
58
+ "#001FFF": "chandelier;pendant;pendent",
59
+ }
60
+
61
+ TABLES = {
62
+ "#FF0652": "table",
63
+ "#0AFF47": "desk",
64
+ }
65
+
66
+ CLOSETS = {
67
+ "#E005FF": "cabinet",
68
+ "#FF0747": "shelf",
69
+ "#07FFFF": "wardrobe;closet;press",
70
+ "#0633FF": "chest;of;drawers;chest;bureau;dresser",
71
+ "#0000FF": "case;display;case;showcase;vitrine",
72
+ }
73
+
74
+
75
+ BATHROOM = {
76
+ "#6608FF": "bathtub;bathing;tub;bath;tub",
77
+ "#00FF85": "toilet;can;commode;crapper;pot;potty;stool;throne",
78
+ "#0085FF": "shower",
79
+ "#FF0066": "towel",
80
+ }
81
+
82
+ WINDOWS = {
83
+ "#FF3307": "curtain;drape;drapery;mantle;pall",
84
+ "#E6E6E6": "windowpane;window",
85
+ "#00FF3D": "awning;sunshade;sunblind",
86
+ "#003DFF": "blind;screen",
87
+ }
88
+
89
+ FLOOR = {
90
+ "#FF095C": "rug;carpet;carpeting",
91
+ "#503232": "floor;flooring",
92
+ }
93
+
94
+ INTERIOR = {
95
+ "#787878": "wall",
96
+ "#787850": "ceiling",
97
+ "#08FF33": "door;double;door",
98
+ }
99
+
100
+ KITCHEN = {
101
+ "#00FF29": "kitchen;island",
102
+ "#14FF00": "refrigerator;icebox",
103
+ "#00A3FF": "sink",
104
+ "#EB0CFF": "counter",
105
+ "#D6FF00": "dishwasher;dish;washer;dishwashing;machine",
106
+ "#FF00EB": "microwave;microwave;oven",
107
+ "#47FF00": "oven",
108
+ "#66FF00": "clock",
109
+ "#00FFB8": "plate",
110
+ "#19C2C2": "glass;drinking;glass",
111
+ "#00FF99": "bar",
112
+ "#00FF0A": "bottle",
113
+ "#FF7000": "buffet;counter;sideboard",
114
+ "#B800FF": "washer;automatic;washer;washing;machine",
115
+ "#00FF70": "coffee;table;cocktail;table",
116
+ "#008FFF": "countertop",
117
+ "#33FF00": "stove;kitchen;stove;range;kitchen;range;cooking;stove",
118
+ }
119
+
120
+ LIVINGROOM = {
121
+ "#FA0A0F": "fireplace;hearth;open;fireplace",
122
+ "#FF4700": "pool;table;billiard;table;snooker;table",
123
+ }
124
+
125
+ OFFICE = {
126
+ "#00FFAD": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
127
+ "#00FFF5": "bookcase",
128
+ "#0633FF": "chest;of;drawers;chest;bureau;dresser",
129
+ "#005CFF": "monitor;monitoring;device",
130
+ }
131
+
132
+
133
+ COLOR_MAPPING_CATEGORY_ = {
134
+ 'keep background': {'#FFFFFF': 'background'},
135
+ 'trivia': TRIVIA,
136
+ 'objects': OBJECTS,
137
+ 'sitting': SITTING,
138
+ 'lighting': LIGHTING,
139
+ 'tables': TABLES,
140
+ 'closets': CLOSETS,
141
+ 'bathroom': BATHROOM,
142
+ 'windows': WINDOWS,
143
+ 'floor': FLOOR,
144
+ 'interior': INTERIOR,
145
+ 'kitchen': KITCHEN,
146
+ 'livingroom': LIVINGROOM,
147
+ 'office': OFFICE}
148
+
149
+
150
+ COLOR_MAPPING_ = {
151
+ '#FFFFFF': 'background',
152
+ "#787878": "wall",
153
+ "#B47878": "building;edifice",
154
+ "#06E6E6": "sky",
155
+ "#503232": "floor;flooring",
156
+ "#04C803": "tree",
157
+ "#787850": "ceiling",
158
+ "#8C8C8C": "road;route",
159
+ "#CC05FF": "bed",
160
+ "#E6E6E6": "windowpane;window",
161
+ "#04FA07": "grass",
162
+ "#E005FF": "cabinet",
163
+ "#EBFF07": "sidewalk;pavement",
164
+ "#96053D": "person;individual;someone;somebody;mortal;soul",
165
+ "#787846": "earth;ground",
166
+ "#08FF33": "door;double;door",
167
+ "#FF0652": "table",
168
+ "#8FFF8C": "mountain;mount",
169
+ "#CCFF04": "plant;flora;plant;life",
170
+ "#FF3307": "curtain;drape;drapery;mantle;pall",
171
+ "#CC4603": "chair",
172
+ "#0066C8": "car;auto;automobile;machine;motorcar",
173
+ "#3DE6FA": "water",
174
+ "#FF0633": "painting;picture",
175
+ "#0B66FF": "sofa;couch;lounge",
176
+ "#FF0747": "shelf",
177
+ "#FF09E0": "house",
178
+ "#0907E6": "sea",
179
+ "#DCDCDC": "mirror",
180
+ "#FF095C": "rug;carpet;carpeting",
181
+ "#7009FF": "field",
182
+ "#08FFD6": "armchair",
183
+ "#07FFE0": "seat",
184
+ "#FFB806": "fence;fencing",
185
+ "#0AFF47": "desk",
186
+ "#FF290A": "rock;stone",
187
+ "#07FFFF": "wardrobe;closet;press",
188
+ "#E0FF08": "lamp",
189
+ "#6608FF": "bathtub;bathing;tub;bath;tub",
190
+ "#FF3D06": "railing;rail",
191
+ "#FFC207": "cushion",
192
+ "#FF7A08": "base;pedestal;stand",
193
+ "#00FF14": "box",
194
+ "#FF0829": "column;pillar",
195
+ "#FF0599": "signboard;sign",
196
+ "#0633FF": "chest;of;drawers;chest;bureau;dresser",
197
+ "#EB0CFF": "counter",
198
+ "#A09614": "sand",
199
+ "#00A3FF": "sink",
200
+ "#8C8C8C": "skyscraper",
201
+ "#FA0A0F": "fireplace;hearth;open;fireplace",
202
+ "#14FF00": "refrigerator;icebox",
203
+ "#1FFF00": "grandstand;covered;stand",
204
+ "#FF1F00": "path",
205
+ "#FFE000": "stairs;steps",
206
+ "#99FF00": "runway",
207
+ "#0000FF": "case;display;case;showcase;vitrine",
208
+ "#FF4700": "pool;table;billiard;table;snooker;table",
209
+ "#00EBFF": "pillow",
210
+ "#00ADFF": "screen;door;screen",
211
+ "#1F00FF": "stairway;staircase",
212
+ "#0BC8C8": "river",
213
+ "#FF5200": "bridge;span",
214
+ "#00FFF5": "bookcase",
215
+ "#003DFF": "blind;screen",
216
+ "#00FF70": "coffee;table;cocktail;table",
217
+ "#00FF85": "toilet;can;commode;crapper;pot;potty;stool;throne",
218
+ "#FF0000": "flower",
219
+ "#FFA300": "book",
220
+ "#FF6600": "hill",
221
+ "#C2FF00": "bench",
222
+ "#008FFF": "countertop",
223
+ "#33FF00": "stove;kitchen;stove;range;kitchen;range;cooking;stove",
224
+ "#0052FF": "palm;palm;tree",
225
+ "#00FF29": "kitchen;island",
226
+ "#00FFAD": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
227
+ "#0A00FF": "swivel;chair",
228
+ "#ADFF00": "boat",
229
+ "#00FF99": "bar",
230
+ "#FF5C00": "arcade;machine",
231
+ "#FF00FF": "hovel;hut;hutch;shack;shanty",
232
+ "#FF00F5": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle",
233
+ "#FF0066": "towel",
234
+ "#FFAD00": "light;light;source",
235
+ "#FF0014": "truck;motortruck",
236
+ "#FFB8B8": "tower",
237
+ "#001FFF": "chandelier;pendant;pendent",
238
+ "#00FF3D": "awning;sunshade;sunblind",
239
+ "#0047FF": "streetlight;street;lamp",
240
+ "#FF00CC": "booth;cubicle;stall;kiosk",
241
+ "#00FFC2": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
242
+ "#00FF52": "airplane;aeroplane;plane",
243
+ "#000AFF": "dirt;track",
244
+ "#0070FF": "apparel;wearing;apparel;dress;clothes",
245
+ "#3300FF": "pole",
246
+ "#00C2FF": "land;ground;soil",
247
+ "#007AFF": "bannister;banister;balustrade;balusters;handrail",
248
+ "#00FFA3": "escalator;moving;staircase;moving;stairway",
249
+ "#FF9900": "ottoman;pouf;pouffe;puff;hassock",
250
+ "#00FF0A": "bottle",
251
+ "#FF7000": "buffet;counter;sideboard",
252
+ "#8FFF00": "poster;posting;placard;notice;bill;card",
253
+ "#5200FF": "stage",
254
+ "#A3FF00": "van",
255
+ "#FFEB00": "ship",
256
+ "#08B8AA": "fountain",
257
+ "#8500FF": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter",
258
+ "#00FF5C": "canopy",
259
+ "#B800FF": "washer;automatic;washer;washing;machine",
260
+ "#FF001F": "plaything;toy",
261
+ "#00B8FF": "swimming;pool;swimming;bath;natatorium",
262
+ "#00D6FF": "stool",
263
+ "#FF0070": "barrel;cask",
264
+ "#5CFF00": "basket;handbasket",
265
+ "#00E0FF": "waterfall;falls",
266
+ "#70E0FF": "tent;collapsible;shelter",
267
+ "#46B8A0": "bag",
268
+ "#A300FF": "minibike;motorbike",
269
+ "#9900FF": "cradle",
270
+ "#47FF00": "oven",
271
+ "#FF00A3": "ball",
272
+ "#FFCC00": "food;solid;food",
273
+ "#FF008F": "step;stair",
274
+ "#00FFEB": "tank;storage;tank",
275
+ "#85FF00": "trade;name;brand;name;brand;marque",
276
+ "#FF00EB": "microwave;microwave;oven",
277
+ "#F500FF": "pot;flowerpot",
278
+ "#FF007A": "animal;animate;being;beast;brute;creature;fauna",
279
+ "#FFF500": "bicycle;bike;wheel;cycle",
280
+ "#0ABED4": "lake",
281
+ "#D6FF00": "dishwasher;dish;washer;dishwashing;machine",
282
+ "#00CCFF": "screen;silver;screen;projection;screen",
283
+ "#1400FF": "blanket;cover",
284
+ "#FFFF00": "sculpture",
285
+ "#0099FF": "hood;exhaust;hood",
286
+ "#0029FF": "sconce",
287
+ "#00FFCC": "vase",
288
+ "#2900FF": "traffic;light;traffic;signal;stoplight",
289
+ "#29FF00": "tray",
290
+ "#AD00FF": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin",
291
+ "#00F5FF": "fan",
292
+ "#4700FF": "pier;wharf;wharfage;dock",
293
+ "#7A00FF": "crt;screen",
294
+ "#00FFB8": "plate",
295
+ "#005CFF": "monitor;monitoring;device",
296
+ "#B8FF00": "bulletin;board;notice;board",
297
+ "#0085FF": "shower",
298
+ "#FFD600": "radiator",
299
+ "#19C2C2": "glass;drinking;glass",
300
+ "#66FF00": "clock",
301
+ "#5C00FF": "flag",
302
+ }
303
+
304
+ def ade_palette() -> List[List[int]]:
305
+ """ADE20K palette that maps each class to RGB values."""
306
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
307
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
308
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
309
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
310
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
311
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
312
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
313
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
314
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
315
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
316
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
317
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
318
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
319
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
320
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
321
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
322
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
323
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
324
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
325
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
326
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
327
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
328
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
329
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
330
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
331
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
332
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
333
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
334
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
335
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
336
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
337
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
338
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
339
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
340
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
341
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
342
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
343
+ [102, 255, 0], [92, 0, 255]]
config.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """File with configs"""
2
+ from palette import COLOR_MAPPING_, COLOR_MAPPING
3
+
4
+ HEIGHT = 512
5
+ WIDTH = 512
6
+
7
+ def to_rgb(color: str) -> tuple:
8
+ """Convert hex color to rgb.
9
+ Args:
10
+ color (str): hex color
11
+ Returns:
12
+ tuple: rgb color
13
+ """
14
+ return tuple(int(color[i:i+2], 16) for i in (1, 3, 5))
15
+
16
+ COLOR_NAMES = list(COLOR_MAPPING.keys())
17
+ COLOR_RGB = [to_rgb(k) for k in COLOR_MAPPING_.keys()] + [(0, 0, 0), (255, 255, 255)]
18
+ INVERSE_COLORS = {v: to_rgb(k) for k, v in COLOR_MAPPING_.items()}
19
+ COLOR_MAPPING_RGB = {to_rgb(k): v for k, v in COLOR_MAPPING_.items()}
20
+
21
+ def map_colors(color: str) -> str:
22
+ """Map color to hex value.
23
+ Args:
24
+ color (str): color name
25
+ Returns:
26
+ str: hex value
27
+ """
28
+ return COLOR_MAPPING[color]
29
+
30
+ def map_colors_rgb(color: tuple) -> str:
31
+ return COLOR_MAPPING_RGB[color]
32
+
33
+
34
+ POS_PROMPT = "tree, sky, cloud, scenery, outdoors, grass, flowers, sunlight, beautiful, ultra detailed beautiful landscape, architectural renderings vegetation, high res, best high quality landscape, outdoor lighting, sunshine, 4k, 8k, realistic"
35
+ NEG_PROMPT= "lowres, deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, mutated hands and fingers, out of frame"
empty_room.jpg ADDED
explanation.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def make_inpainting_explanation():
4
+ with st.expander("Explanation inpainting", expanded=False):
5
+ st.write("In the inpainting mode, you can draw regions on the input image that you want to regenerate. "
6
+ "This can be useful to remove unwanted objects from the image or to improve the consistency of the image."
7
+ )
8
+ st.image("content/inpainting_sidebar.png", caption="Image before inpainting, note the ornaments on the wall", width=500)
9
+ st.write("You can find drawing options in the sidebar. There are two modes: freedraw and polygon. Freedraw allows the user to draw with a pencil of a certain width. "
10
+ "Polygon allows the user to draw a polygon by clicking on the image to add a point. The polygon is closed by right clicking.")
11
+
12
+ st.write("### Example inpainting")
13
+ st.write("In the example below, the ornaments on the wall are removed. The inpainting is done by drawing a mask on the image.")
14
+ st.image("content/inpainting_before.jpg", caption="Image before inpainting, note the ornaments on the wall")
15
+ st.image("content/inpainting_after.png", caption="Image before inpainting, note the ornaments on the wall")
16
+
17
+ def make_regeneration_explanation():
18
+ with st.expander("Explanation object regeneration"):
19
+ st.write("In this object regeneration mode, the model calculates which objects occur in the image. "
20
+ "The user can then select which objects can be regenerated by the controlnet model by adding them in the multiselect box. "
21
+ "All the object classes that are not selected will remain the same as in the original image."
22
+ )
23
+ st.write("### Example object regeneration")
24
+ st.write("In the example below, the room consists of various objects such as wall, ceiling, floor, lamp, bed, ... "
25
+ "In the multiselect box, all the objects except for 'lamp', 'bed and 'table' are selected to be regenerated. "
26
+ )
27
+ st.image("content/regen_example.png", caption="Room where all concepts except for 'bed', 'lamp', 'table' are regenerated")
28
+
29
+ def make_segmentation_explanation():
30
+ with st.expander("Segmentation mode", expanded=False):
31
+ st.write("In the segmentation mode, the user can use his imagination and the paint brush to place concepts in the image. "
32
+ "In the left sidebar, you can first find the high level category of the concept you want to add, such as 'lighting', 'floor', .. "
33
+ "After selecting the category, you can select the specific concept you want to add in the 'Choose a color' dropdown. "
34
+ "This will change the color of the paint brush, which you can then use to draw on the input image. "
35
+ "The model will then regenerate the image with the concepts you have drawn and leave the rest of the image unchanged. "
36
+ )
37
+ st.image("content/sidebar segmentation.png", caption="Sidebar with segmentation options", width=300)
38
+ st.write("You can choose the freedraw mode which gives you a pencil of a certain (chosen) width or the polygon mode. With the polygon mode you can click to add a point to the polygon and close the polygon by right clicking. ")
39
+ st.write("Important: "
40
+ "it's not easy to draw a good segmentation mask. This is because you need to keep in mind the perspective of the room and the exact "
41
+ "shape of the object you want to draw within this perspective. Controlnet will follow your segmentation mask pretty well, so "
42
+ "a non-natural object shape will sometimes result in weird outputs. However, give it a try and see what you can do! "
43
+ )
44
+ st.image("content/segmentation window.png", caption="Example of a segmentation mask drawn on the input image to add a window to the room")
45
+ st.write("Tip: ")
46
+ st.write("In the concepts dropdown, you can select 'keep background' (which is a white color). Everything drawn in this color will use "
47
+ "the original underlying segmentation mask. This can be useful to help with generating other objects, since you give the model a some "
48
+ "freedom to generate outside the object borders."
49
+ )
50
+ st.image("content/keep background 1.png", caption="Image with a poster drawn on the wall.")
51
+ st.image("content/keep background 2.png", caption="Image with a poster drawn on the wall surrounded by 'keep background'.")
helpers.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import torch
3
+ from scipy.signal import fftconvolve
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ def flush():
8
+ gc.collect()
9
+ torch.cuda.empty_cache()
10
+
11
+
12
+
13
+ def convolution(mask: Image.Image, size=9) -> Image:
14
+ """Method to blur the mask
15
+ Args:
16
+ mask (Image): masking image
17
+ size (int, optional): size of the blur. Defaults to 9.
18
+ Returns:
19
+ Image: blurred mask
20
+ """
21
+ mask = np.array(mask.convert("L"))
22
+ conv = np.ones((size, size)) / size**2
23
+ mask_blended = fftconvolve(mask, conv, 'same')
24
+ mask_blended = mask_blended.astype(np.uint8).copy()
25
+
26
+ border = size
27
+
28
+ # replace borders with original values
29
+ mask_blended[:border, :] = mask[:border, :]
30
+ mask_blended[-border:, :] = mask[-border:, :]
31
+ mask_blended[:, :border] = mask[:, :border]
32
+ mask_blended[:, -border:] = mask[:, -border:]
33
+
34
+ return Image.fromarray(mask_blended).convert("L")
35
+
36
+
37
+ def postprocess_image_masking(inpainted: Image, image: Image, mask: Image) -> Image:
38
+ """Method to postprocess the inpainted image
39
+ Args:
40
+ inpainted (Image): inpainted image
41
+ image (Image): original image
42
+ mask (Image): mask
43
+ Returns:
44
+ Image: inpainted image
45
+ """
46
+ final_inpainted = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
47
+ return final_inpainted.convert("RGB")
models.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file contains methods for inference and image generation."""
2
+ import logging
3
+ from typing import List, Tuple, Dict
4
+
5
+ import streamlit as st
6
+ import torch
7
+ import gc
8
+ import time
9
+ import numpy as np
10
+ from PIL import Image
11
+ from PIL import ImageFilter
12
+
13
+ from diffusers import ControlNetModel, UniPCMultistepScheduler
14
+
15
+ from config import WIDTH, HEIGHT
16
+ from palette import ade_palette
17
+ from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
18
+ from helpers import flush, postprocess_image_masking, convolution
19
+ from pipelines import ControlNetPipeline, SDPipeline, get_inpainting_pipeline, get_controlnet
20
+
21
+ LOGGING = logging.getLogger(__name__)
22
+
23
+
24
+ @torch.inference_mode()
25
+ def make_image_controlnet(image: np.ndarray,
26
+ mask_image: np.ndarray,
27
+ controlnet_conditioning_image: np.ndarray,
28
+ positive_prompt: str, negative_prompt: str,
29
+ seed: int = 2356132) -> List[Image.Image]:
30
+ """Method to make image using controlnet
31
+ Args:
32
+ image (np.ndarray): input image
33
+ mask_image (np.ndarray): mask image
34
+ controlnet_conditioning_image (np.ndarray): conditioning image
35
+ positive_prompt (str): positive prompt string
36
+ negative_prompt (str): negative prompt string
37
+ seed (int, optional): seed. Defaults to 2356132.
38
+ Returns:
39
+ List[Image.Image]: list of generated images
40
+ """
41
+
42
+ pipe = get_controlnet()
43
+ flush()
44
+
45
+ image = Image.fromarray(image).convert("RGB")
46
+ controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB")#.filter(ImageFilter.GaussianBlur(radius = 9))
47
+ mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
48
+ mask_image_postproc = convolution(mask_image)
49
+
50
+
51
+ st.success(f"{pipe.queue_size} images in the queue, can take up to {(pipe.queue_size+1) * 10} seconds")
52
+ generated_image = pipe(
53
+ prompt=positive_prompt,
54
+ negative_prompt=negative_prompt,
55
+ num_inference_steps=50,
56
+ strength=1.00,
57
+ guidance_scale=7.0,
58
+ generator=[torch.Generator(device="cuda").manual_seed(seed)],
59
+ image=image,
60
+ mask_image=mask_image,
61
+ controlnet_conditioning_image=controlnet_conditioning_image,
62
+ ).images[0]
63
+ generated_image = postprocess_image_masking(generated_image, image, mask_image_postproc)
64
+
65
+ return generated_image
66
+
67
+
68
+ @torch.inference_mode()
69
+ def make_inpainting(positive_prompt: str,
70
+ image: Image,
71
+ mask_image: np.ndarray,
72
+ negative_prompt: str = "") -> List[Image.Image]:
73
+ """Method to make inpainting
74
+ Args:
75
+ positive_prompt (str): positive prompt string
76
+ image (Image): input image
77
+ mask_image (np.ndarray): mask image
78
+ negative_prompt (str, optional): negative prompt string. Defaults to "".
79
+ Returns:
80
+ List[Image.Image]: list of generated images
81
+ """
82
+ pipe = get_inpainting_pipeline()
83
+ mask_image = Image.fromarray((mask_image * 255).astype(np.uint8))
84
+ mask_image_postproc = convolution(mask_image)
85
+
86
+ flush()
87
+ st.success(f"{pipe.queue_size} images in the queue, can take up to {(pipe.queue_size+1) * 10} seconds")
88
+ generated_image = pipe(image=image,
89
+ mask_image=mask_image,
90
+ prompt=positive_prompt,
91
+ negative_prompt=negative_prompt,
92
+ num_inference_steps=50,
93
+ height=HEIGHT,
94
+ width=WIDTH,
95
+ ).images[0]
96
+ generated_image = postprocess_image_masking(generated_image, image, mask_image_postproc)
97
+
98
+ return generated_image
palette.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file contains color information"""
2
+ from typing import List, Dict
3
+ from colors import COLOR_MAPPING_, COLOR_MAPPING_CATEGORY_, ade_palette
4
+
5
+
6
+ def convert_hex_to_rgba(hex_code: str) -> str:
7
+ """Convert hex code to rgba.
8
+ Args:
9
+ hex_code (str): hex string
10
+ Returns:
11
+ str: rgba string
12
+ """
13
+ hex_code = hex_code.lstrip('#')
14
+ return "rgba(" + str(int(hex_code[0:2], 16)) + ", " + str(int(hex_code[2:4], 16)) + ", " + str(int(hex_code[4:6], 16)) + ", 1.0)"
15
+
16
+
17
+ def convert_dict_to_rgba(color_dict: Dict) -> Dict:
18
+ """Convert hex code to rgba for all elements in a dictionary.
19
+ Args:
20
+ color_dict (Dict): color dictionary
21
+ Returns:
22
+ Dict: color dictionary with rgba values
23
+ """
24
+ updated_dict = {}
25
+ for k, v in color_dict.items():
26
+ updated_dict[convert_hex_to_rgba(k)] = v
27
+ return updated_dict
28
+
29
+
30
+ def convert_nested_dict_to_rgba(nested_dict):
31
+ updated_dict = {}
32
+ for k, v in nested_dict.items():
33
+ updated_dict[k] = convert_dict_to_rgba(v)
34
+ return updated_dict
35
+
36
+
37
+ COLOR_MAPPING = convert_dict_to_rgba(COLOR_MAPPING_)
38
+ COLOR_MAPPING_CATEGORY = convert_nested_dict_to_rgba(COLOR_MAPPING_CATEGORY_)
pipelines.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Tuple, Dict
3
+
4
+ import streamlit as st
5
+ import torch
6
+ import gc
7
+ import time
8
+ import numpy as np
9
+ from PIL import Image
10
+ from time import perf_counter
11
+ from contextlib import contextmanager
12
+ from scipy.signal import fftconvolve
13
+ from PIL import ImageFilter
14
+
15
+ from diffusers import ControlNetModel, UniPCMultistepScheduler
16
+ from diffusers import StableDiffusionInpaintPipeline
17
+
18
+ from config import WIDTH, HEIGHT
19
+ from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
20
+ from helpers import flush
21
+
22
+ LOGGING = logging.getLogger(__name__)
23
+
24
+ class ControlNetPipeline:
25
+ def __init__(self):
26
+ self.in_use = False
27
+ self.controlnet = ControlNetModel.from_pretrained(
28
+ "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
29
+
30
+ self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
31
+ "runwayml/stable-diffusion-inpainting",
32
+ controlnet=self.controlnet,
33
+ safety_checker=None,
34
+ torch_dtype=torch.float16
35
+ )
36
+
37
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
38
+ self.pipe.enable_xformers_memory_efficient_attention()
39
+ self.pipe = self.pipe.to("cuda")
40
+
41
+ self.waiting_queue = []
42
+ self.count = 0
43
+
44
+ @property
45
+ def queue_size(self):
46
+ return len(self.waiting_queue)
47
+
48
+ def __call__(self, **kwargs):
49
+ self.count += 1
50
+ number = self.count
51
+
52
+ self.waiting_queue.append(number)
53
+
54
+ # wait until the next number in the queue is the current number
55
+ while self.waiting_queue[0] != number:
56
+ print(f"Wait for your turn {number} in queue {self.waiting_queue}")
57
+ time.sleep(0.5)
58
+ pass
59
+
60
+ # it's your turn, so remove the number from the queue
61
+ # and call the function
62
+ print("It's the turn of", self.count)
63
+ results = self.pipe(**kwargs)
64
+ self.waiting_queue.pop(0)
65
+ flush()
66
+ return results
67
+
68
+ class SDPipeline:
69
+ def __init__(self):
70
+ self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
71
+ "stabilityai/stable-diffusion-2-inpainting",
72
+ torch_dtype=torch.float16,
73
+ safety_checker=None,
74
+ )
75
+
76
+ self.pipe.enable_xformers_memory_efficient_attention()
77
+ # self.pipe = self.pipe.to("cuda")
78
+
79
+ self.waiting_queue = []
80
+ self.count = 0
81
+
82
+ @property
83
+ def queue_size(self):
84
+ return len(self.waiting_queue)
85
+
86
+ def __call__(self, **kwargs):
87
+ self.count += 1
88
+ number = self.count
89
+
90
+ self.waiting_queue.append(number)
91
+
92
+ # wait until the next number in the queue is the current number
93
+ while self.waiting_queue[0] != number:
94
+ print(f"Wait for your turn {number} in queue {self.waiting_queue}")
95
+ time.sleep(0.5)
96
+ pass
97
+
98
+ # it's your turn, so remove the number from the queue
99
+ # and call the function
100
+ print("It's the turn of", self.count)
101
+ results = self.pipe(**kwargs)
102
+ self.waiting_queue.pop(0)
103
+ flush()
104
+ return results
105
+
106
+
107
+
108
+ @st.cache_resource(max_entries=5)
109
+ def get_controlnet():
110
+ """Method to load the controlnet model
111
+ Returns:
112
+ ControlNetModel: controlnet model
113
+ """
114
+ pipe = ControlNetPipeline()
115
+ return pipe
116
+
117
+
118
+
119
+ @st.cache_resource(max_entries=5)
120
+ def get_inpainting_pipeline():
121
+ """Method to load the inpainting pipeline
122
+ Returns:
123
+ StableDiffusionInpaintPipeline: inpainting pipeline
124
+ """
125
+ pipe = SDPipeline()
126
+ return pipe
preprocessing.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Preprocessing methods"""
2
+ import logging
3
+ from typing import List, Tuple
4
+
5
+ import numpy as np
6
+ from PIL import Image, ImageFilter
7
+ import streamlit as st
8
+
9
+ from config import COLOR_RGB, WIDTH, HEIGHT
10
+ # from enhance_config import ENHANCE_SETTINGS
11
+
12
+ LOGGING = logging.getLogger(__name__)
13
+
14
+
15
+ def preprocess_seg_mask(canvas_seg, real_seg: Image.Image = None) -> Tuple[np.ndarray, np.ndarray]:
16
+ """Preprocess the segmentation mask.
17
+ Args:
18
+ canvas_seg: segmentation canvas
19
+ real_seg (Image.Image, optional): segmentation mask. Defaults to None.
20
+ Returns:
21
+ Tuple[np.ndarray, np.ndarray]: segmentation mask, segmentation mask with overlay
22
+ """
23
+ # get unique colors in the segmentation
24
+ image_seg = canvas_seg.image_data.copy()[:, :, :3]
25
+
26
+ # average the colors of the segmentation masks
27
+ average_color = np.mean(image_seg, axis=(2))
28
+ mask = average_color[:, :] > 0
29
+ if mask.sum() > 0:
30
+ mask = mask * 1
31
+
32
+ unique_colors = np.unique(image_seg.reshape(-1, image_seg.shape[-1]), axis=0)
33
+ unique_colors = [tuple(color) for color in unique_colors]
34
+
35
+ unique_colors = [color for color in unique_colors if np.sum(
36
+ np.all(image_seg == color, axis=-1)) > 100]
37
+
38
+ unique_colors_exact = [color for color in unique_colors if color in COLOR_RGB]
39
+
40
+ if real_seg is not None:
41
+ overlay_seg = np.array(real_seg)
42
+
43
+ unique_colors = np.unique(overlay_seg.reshape(-1, overlay_seg.shape[-1]), axis=0)
44
+ unique_colors = [tuple(color) for color in unique_colors]
45
+
46
+ for color in unique_colors_exact:
47
+ if color != (255, 255, 255) and color != (0, 0, 0):
48
+ overlay_seg[np.all(image_seg == color, axis=-1)] = color
49
+ image_seg = overlay_seg
50
+
51
+ return mask, image_seg
52
+
53
+
54
+ def get_mask(image_mask: np.ndarray) -> np.ndarray:
55
+ """Get the mask from the segmentation mask.
56
+ Args:
57
+ image_mask (np.ndarray): segmentation mask
58
+ Returns:
59
+ np.ndarray: mask
60
+ """
61
+ # average the colors of the segmentation masks
62
+ average_color = np.mean(image_mask, axis=(2))
63
+ mask = average_color[:, :] > 0
64
+ if mask.sum() > 0:
65
+ mask = mask * 1
66
+ return mask
67
+
68
+
69
+ def get_image() -> np.ndarray:
70
+ """Get the image from the session state.
71
+ Returns:
72
+ np.ndarray: image
73
+ """
74
+ if 'initial_image' in st.session_state and st.session_state['initial_image'] is not None:
75
+ initial_image = st.session_state['initial_image']
76
+ if isinstance(initial_image, Image.Image):
77
+ return np.array(initial_image.resize((WIDTH, HEIGHT)))
78
+ else:
79
+ return np.array(Image.fromarray(initial_image).resize((WIDTH, HEIGHT)))
80
+ else:
81
+ return None
82
+
83
+
84
+ # def make_enhance_config(segmentation, objects=None):
85
+ """Make the enhance config for the segmentation image.
86
+ """
87
+ info = ENHANCE_SETTINGS[objects]
88
+
89
+ segmentation = np.array(segmentation)
90
+
91
+ if 'replace' in info:
92
+ replace_color = info['replace']
93
+ mask = np.zeros(segmentation.shape)
94
+ for color in info['colors']:
95
+ mask[np.all(segmentation == color, axis=-1)] = [1, 1, 1]
96
+ segmentation[np.all(segmentation == color, axis=-1)] = replace_color
97
+
98
+ if info['inverse'] is False:
99
+ mask = np.zeros(segmentation.shape)
100
+ for color in info['colors']:
101
+ mask[np.all(segmentation == color, axis=-1)] = [1, 1, 1]
102
+ else:
103
+ mask = np.ones(segmentation.shape)
104
+ for color in info['colors']:
105
+ mask[np.all(segmentation == color, axis=-1)] = [0, 0, 0]
106
+
107
+ st.session_state['positive_prompt'] = info['positive_prompt']
108
+ st.session_state['negative_prompt'] = info['negative_prompt']
109
+
110
+ if info['inpainting'] is True:
111
+ mask = mask.astype(np.uint8)
112
+ mask = Image.fromarray(mask)
113
+ mask = mask.filter(ImageFilter.GaussianBlur(radius=13))
114
+ mask = mask.filter(ImageFilter.MaxFilter(size=9))
115
+ mask = np.array(mask)
116
+
117
+ mask[mask < 0.1] = 0
118
+ mask[mask >= 0.1] = 1
119
+ mask = mask.astype(np.uint8)
120
+
121
+ conditioning = dict(
122
+ mask_image=mask,
123
+ positive_prompt=info['positive_prompt'],
124
+ negative_prompt=info['negative_prompt'],
125
+ )
126
+ else:
127
+ conditioning = dict(
128
+ mask_image=mask,
129
+ controlnet_conditioning_image=segmentation,
130
+ positive_prompt=info['positive_prompt'],
131
+ negative_prompt=info['negative_prompt'],
132
+ strength=info['strength']
133
+ )
134
+ return conditioning, info['inpainting']
requirements.txt CHANGED
@@ -1,14 +1,14 @@
1
- streamlit
2
- streamlit-drawable-canvas
3
- diffusers
4
- xformers
5
- transformers
6
- torchvision
7
  git+https://github.com/huggingface/accelerate.git
8
- opencv-python-headless
9
- scipy
10
  python-docx
11
- extra-streamlit-components
12
  triton
13
- altair
14
  gradio
 
1
+ streamlit==1.20.0
2
+ streamlit-drawable-canvas==0.9.0
3
+ diffusers==0.15.0
4
+ xformers==0.0.16
5
+ transformers==4.28.0
6
+ torchvision==0.14.1
7
  git+https://github.com/huggingface/accelerate.git
8
+ opencv-python-headless==4.7.0.72
9
+ scipy==1.10.0
10
  python-docx
11
+ extra-streamlit-components==0.1.56
12
  triton
13
+ altair==4.1.0
14
  gradio
segmentation.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Tuple, Dict
3
+
4
+ import streamlit as st
5
+ import torch
6
+ import gc
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
11
+
12
+ from palette import ade_palette
13
+
14
+ LOGGING = logging.getLogger(__name__)
15
+
16
+
17
+ def flush():
18
+ gc.collect()
19
+ torch.cuda.empty_cache()
20
+
21
+ @st.cache_resource(max_entries=5)
22
+ def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
23
+ """Method to load the segmentation pipeline
24
+ Returns:
25
+ Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
26
+ """
27
+ image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
28
+ image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
29
+ "openmmlab/upernet-convnext-small")
30
+ return image_processor, image_segmentor
31
+
32
+
33
+ @torch.inference_mode()
34
+ @torch.autocast('cuda')
35
+ def segment_image(image: Image) -> Image:
36
+ """Method to segment image
37
+ Args:
38
+ image (Image): input image
39
+ Returns:
40
+ Image: segmented image
41
+ """
42
+ image_processor, image_segmentor = get_segmentation_pipeline()
43
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values
44
+ with torch.no_grad():
45
+ outputs = image_segmentor(pixel_values)
46
+
47
+ seg = image_processor.post_process_semantic_segmentation(
48
+ outputs, target_sizes=[image.size[::-1]])[0]
49
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
50
+ palette = np.array(ade_palette())
51
+ for label, color in enumerate(palette):
52
+ color_seg[seg == label, :] = color
53
+ color_seg = color_seg.astype(np.uint8)
54
+ seg_image = Image.fromarray(color_seg).convert('RGB')
55
+ return seg_image
stable_diffusion_controlnet_inpaint_img2img.py ADDED
@@ -0,0 +1,1112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file contains the StableDiffusionControlNetInpaintImg2ImgPipeline class from the
2
+ community pipelines from the diffusers library of HuggingFace.
3
+ """
4
+ # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
5
+
6
+ import inspect
7
+ from typing import Any, Callable, Dict, List, Optional, Union
8
+
9
+ import numpy as np
10
+ import PIL.Image
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
14
+
15
+ from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
16
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+ from diffusers.utils import (
19
+ PIL_INTERPOLATION,
20
+ is_accelerate_available,
21
+ is_accelerate_version,
22
+ randn_tensor,
23
+ replace_example_docstring,
24
+ )
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+ EXAMPLE_DOC_STRING = """
30
+ Examples:
31
+ ```py
32
+ >>> import numpy as np
33
+ >>> import torch
34
+ >>> from PIL import Image
35
+ >>> from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
36
+ >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
37
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
38
+ >>> from diffusers.utils import load_image
39
+ >>> def ade_palette():
40
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
41
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
42
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
43
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
44
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
45
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
46
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
47
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
48
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
49
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
50
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
51
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
52
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
53
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
54
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
55
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
56
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
57
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
58
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
59
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
60
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
61
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
62
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
63
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
64
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
65
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
66
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
67
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
68
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
69
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
70
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
71
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
72
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
73
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
74
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
75
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
76
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
77
+ [102, 255, 0], [92, 0, 255]]
78
+ >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
79
+ >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
80
+ >>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
81
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
82
+ )
83
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
84
+ >>> pipe.enable_xformers_memory_efficient_attention()
85
+ >>> pipe.enable_model_cpu_offload()
86
+ >>> def image_to_seg(image):
87
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values
88
+ with torch.no_grad():
89
+ outputs = image_segmentor(pixel_values)
90
+ seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
91
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
92
+ palette = np.array(ade_palette())
93
+ for label, color in enumerate(palette):
94
+ color_seg[seg == label, :] = color
95
+ color_seg = color_seg.astype(np.uint8)
96
+ seg_image = Image.fromarray(color_seg)
97
+ return seg_image
98
+ >>> image = load_image(
99
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
100
+ )
101
+ >>> mask_image = load_image(
102
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
103
+ )
104
+ >>> controlnet_conditioning_image = image_to_seg(image)
105
+ >>> image = pipe(
106
+ "Face of a yellow cat, high resolution, sitting on a park bench",
107
+ image,
108
+ mask_image,
109
+ controlnet_conditioning_image,
110
+ num_inference_steps=20,
111
+ ).images[0]
112
+ >>> image.save("out.png")
113
+ ```
114
+ """
115
+
116
+
117
+ def prepare_image(image):
118
+ if isinstance(image, torch.Tensor):
119
+ # Batch single image
120
+ if image.ndim == 3:
121
+ image = image.unsqueeze(0)
122
+
123
+ image = image.to(dtype=torch.float32)
124
+ else:
125
+ # preprocess image
126
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
127
+ image = [image]
128
+
129
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
130
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
131
+ image = np.concatenate(image, axis=0)
132
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
133
+ image = np.concatenate([i[None, :] for i in image], axis=0)
134
+
135
+ image = image.transpose(0, 3, 1, 2)
136
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
137
+
138
+ return image
139
+
140
+
141
+ def prepare_mask_image(mask_image):
142
+ if isinstance(mask_image, torch.Tensor):
143
+ if mask_image.ndim == 2:
144
+ # Batch and add channel dim for single mask
145
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0)
146
+ elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
147
+ # Single mask, the 0'th dimension is considered to be
148
+ # the existing batch size of 1
149
+ mask_image = mask_image.unsqueeze(0)
150
+ elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
151
+ # Batch of mask, the 0'th dimension is considered to be
152
+ # the batching dimension
153
+ mask_image = mask_image.unsqueeze(1)
154
+
155
+ # Binarize mask
156
+ mask_image[mask_image < 0.5] = 0
157
+ mask_image[mask_image >= 0.5] = 1
158
+ else:
159
+ # preprocess mask
160
+ if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
161
+ mask_image = [mask_image]
162
+
163
+ if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
164
+ mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
165
+ mask_image = mask_image.astype(np.float32) / 255.0
166
+ elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
167
+ mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
168
+
169
+ mask_image[mask_image < 0.5] = 0
170
+ mask_image[mask_image >= 0.5] = 1
171
+ mask_image = torch.from_numpy(mask_image)
172
+
173
+ return mask_image
174
+
175
+
176
+ def prepare_controlnet_conditioning_image(
177
+ controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
178
+ ):
179
+ if not isinstance(controlnet_conditioning_image, torch.Tensor):
180
+ if isinstance(controlnet_conditioning_image, PIL.Image.Image):
181
+ controlnet_conditioning_image = [controlnet_conditioning_image]
182
+
183
+ if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
184
+ controlnet_conditioning_image = [
185
+ np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
186
+ for i in controlnet_conditioning_image
187
+ ]
188
+ controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
189
+ controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
190
+ controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
191
+ controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
192
+ elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
193
+ controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
194
+
195
+ image_batch_size = controlnet_conditioning_image.shape[0]
196
+
197
+ if image_batch_size == 1:
198
+ repeat_by = batch_size
199
+ else:
200
+ # image batch size is the same as prompt batch size
201
+ repeat_by = num_images_per_prompt
202
+
203
+ controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
204
+
205
+ controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
206
+
207
+ return controlnet_conditioning_image
208
+
209
+
210
+ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
211
+ """
212
+ Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
213
+ """
214
+
215
+ _optional_components = ["safety_checker", "feature_extractor"]
216
+
217
+ def __init__(
218
+ self,
219
+ vae: AutoencoderKL,
220
+ text_encoder: CLIPTextModel,
221
+ tokenizer: CLIPTokenizer,
222
+ unet: UNet2DConditionModel,
223
+ controlnet: ControlNetModel,
224
+ scheduler: KarrasDiffusionSchedulers,
225
+ safety_checker: StableDiffusionSafetyChecker,
226
+ feature_extractor: CLIPFeatureExtractor,
227
+ requires_safety_checker: bool = True,
228
+ ):
229
+ super().__init__()
230
+
231
+ if safety_checker is None and requires_safety_checker:
232
+ logger.warning(
233
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
234
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
235
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
236
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
237
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
238
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
239
+ )
240
+
241
+ if safety_checker is not None and feature_extractor is None:
242
+ raise ValueError(
243
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
244
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
245
+ )
246
+
247
+ self.register_modules(
248
+ vae=vae,
249
+ text_encoder=text_encoder,
250
+ tokenizer=tokenizer,
251
+ unet=unet,
252
+ controlnet=controlnet,
253
+ scheduler=scheduler,
254
+ safety_checker=safety_checker,
255
+ feature_extractor=feature_extractor,
256
+ )
257
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
258
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
259
+
260
+ def enable_vae_slicing(self):
261
+ r"""
262
+ Enable sliced VAE decoding.
263
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
264
+ steps. This is useful to save some memory and allow larger batch sizes.
265
+ """
266
+ self.vae.enable_slicing()
267
+
268
+ def disable_vae_slicing(self):
269
+ r"""
270
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
271
+ computing decoding in one step.
272
+ """
273
+ self.vae.disable_slicing()
274
+
275
+ def enable_sequential_cpu_offload(self, gpu_id=0):
276
+ r"""
277
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
278
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
279
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
280
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
281
+ `enable_model_cpu_offload`, but performance is lower.
282
+ """
283
+ if is_accelerate_available():
284
+ from accelerate import cpu_offload
285
+ else:
286
+ raise ImportError("Please install accelerate via `pip install accelerate`")
287
+
288
+ device = torch.device(f"cuda:{gpu_id}")
289
+
290
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
291
+ cpu_offload(cpu_offloaded_model, device)
292
+
293
+ if self.safety_checker is not None:
294
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
295
+
296
+ def enable_model_cpu_offload(self, gpu_id=0):
297
+ r"""
298
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
299
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
300
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
301
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
302
+ """
303
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
304
+ from accelerate import cpu_offload_with_hook
305
+ else:
306
+ raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
307
+
308
+ device = torch.device(f"cuda:{gpu_id}")
309
+
310
+ hook = None
311
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
312
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
313
+
314
+ if self.safety_checker is not None:
315
+ # the safety checker can offload the vae again
316
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
317
+
318
+ # control net hook has be manually offloaded as it alternates with unet
319
+ cpu_offload_with_hook(self.controlnet, device)
320
+
321
+ # We'll offload the last model manually.
322
+ self.final_offload_hook = hook
323
+
324
+ @property
325
+ def _execution_device(self):
326
+ r"""
327
+ Returns the device on which the pipeline's models will be executed. After calling
328
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
329
+ hooks.
330
+ """
331
+ if not hasattr(self.unet, "_hf_hook"):
332
+ return self.device
333
+ for module in self.unet.modules():
334
+ if (
335
+ hasattr(module, "_hf_hook")
336
+ and hasattr(module._hf_hook, "execution_device")
337
+ and module._hf_hook.execution_device is not None
338
+ ):
339
+ return torch.device(module._hf_hook.execution_device)
340
+ return self.device
341
+
342
+ def _encode_prompt(
343
+ self,
344
+ prompt,
345
+ device,
346
+ num_images_per_prompt,
347
+ do_classifier_free_guidance,
348
+ negative_prompt=None,
349
+ prompt_embeds: Optional[torch.FloatTensor] = None,
350
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
351
+ ):
352
+ r"""
353
+ Encodes the prompt into text encoder hidden states.
354
+ Args:
355
+ prompt (`str` or `List[str]`, *optional*):
356
+ prompt to be encoded
357
+ device: (`torch.device`):
358
+ torch device
359
+ num_images_per_prompt (`int`):
360
+ number of images that should be generated per prompt
361
+ do_classifier_free_guidance (`bool`):
362
+ whether to use classifier free guidance or not
363
+ negative_prompt (`str` or `List[str]`, *optional*):
364
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
365
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
366
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
367
+ prompt_embeds (`torch.FloatTensor`, *optional*):
368
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
369
+ provided, text embeddings will be generated from `prompt` input argument.
370
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
371
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
372
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
373
+ argument.
374
+ """
375
+ if prompt is not None and isinstance(prompt, str):
376
+ batch_size = 1
377
+ elif prompt is not None and isinstance(prompt, list):
378
+ batch_size = len(prompt)
379
+ else:
380
+ batch_size = prompt_embeds.shape[0]
381
+
382
+ if prompt_embeds is None:
383
+ text_inputs = self.tokenizer(
384
+ prompt,
385
+ padding="max_length",
386
+ max_length=self.tokenizer.model_max_length,
387
+ truncation=True,
388
+ return_tensors="pt",
389
+ )
390
+ text_input_ids = text_inputs.input_ids
391
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
392
+
393
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
394
+ text_input_ids, untruncated_ids
395
+ ):
396
+ removed_text = self.tokenizer.batch_decode(
397
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
398
+ )
399
+ logger.warning(
400
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
401
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
402
+ )
403
+
404
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
405
+ attention_mask = text_inputs.attention_mask.to(device)
406
+ else:
407
+ attention_mask = None
408
+
409
+ prompt_embeds = self.text_encoder(
410
+ text_input_ids.to(device),
411
+ attention_mask=attention_mask,
412
+ )
413
+ prompt_embeds = prompt_embeds[0]
414
+
415
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
416
+
417
+ bs_embed, seq_len, _ = prompt_embeds.shape
418
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
419
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
420
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
421
+
422
+ # get unconditional embeddings for classifier free guidance
423
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
424
+ uncond_tokens: List[str]
425
+ if negative_prompt is None:
426
+ uncond_tokens = [""] * batch_size
427
+ elif type(prompt) is not type(negative_prompt):
428
+ raise TypeError(
429
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
430
+ f" {type(prompt)}."
431
+ )
432
+ elif isinstance(negative_prompt, str):
433
+ uncond_tokens = [negative_prompt]
434
+ elif batch_size != len(negative_prompt):
435
+ raise ValueError(
436
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
437
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
438
+ " the batch size of `prompt`."
439
+ )
440
+ else:
441
+ uncond_tokens = negative_prompt
442
+
443
+ max_length = prompt_embeds.shape[1]
444
+ uncond_input = self.tokenizer(
445
+ uncond_tokens,
446
+ padding="max_length",
447
+ max_length=max_length,
448
+ truncation=True,
449
+ return_tensors="pt",
450
+ )
451
+
452
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
453
+ attention_mask = uncond_input.attention_mask.to(device)
454
+ else:
455
+ attention_mask = None
456
+
457
+ negative_prompt_embeds = self.text_encoder(
458
+ uncond_input.input_ids.to(device),
459
+ attention_mask=attention_mask,
460
+ )
461
+ negative_prompt_embeds = negative_prompt_embeds[0]
462
+
463
+ if do_classifier_free_guidance:
464
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
465
+ seq_len = negative_prompt_embeds.shape[1]
466
+
467
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
468
+
469
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
470
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
471
+
472
+ # For classifier free guidance, we need to do two forward passes.
473
+ # Here we concatenate the unconditional and text embeddings into a single batch
474
+ # to avoid doing two forward passes
475
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
476
+
477
+ return prompt_embeds
478
+
479
+ def run_safety_checker(self, image, device, dtype):
480
+ if self.safety_checker is not None:
481
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
482
+ image, has_nsfw_concept = self.safety_checker(
483
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
484
+ )
485
+ else:
486
+ has_nsfw_concept = None
487
+ return image, has_nsfw_concept
488
+
489
+ def decode_latents(self, latents):
490
+ latents = 1 / self.vae.config.scaling_factor * latents
491
+ image = self.vae.decode(latents).sample
492
+ image = (image / 2 + 0.5).clamp(0, 1)
493
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
494
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
495
+ return image
496
+
497
+ def prepare_extra_step_kwargs(self, generator, eta):
498
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
499
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
500
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
501
+ # and should be between [0, 1]
502
+
503
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
504
+ extra_step_kwargs = {}
505
+ if accepts_eta:
506
+ extra_step_kwargs["eta"] = eta
507
+
508
+ # check if the scheduler accepts generator
509
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
510
+ if accepts_generator:
511
+ extra_step_kwargs["generator"] = generator
512
+ return extra_step_kwargs
513
+
514
+ def check_inputs(
515
+ self,
516
+ prompt,
517
+ image,
518
+ mask_image,
519
+ controlnet_conditioning_image,
520
+ height,
521
+ width,
522
+ callback_steps,
523
+ negative_prompt=None,
524
+ prompt_embeds=None,
525
+ negative_prompt_embeds=None,
526
+ strength=None,
527
+ ):
528
+ if height % 8 != 0 or width % 8 != 0:
529
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
530
+
531
+ if (callback_steps is None) or (
532
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
533
+ ):
534
+ raise ValueError(
535
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
536
+ f" {type(callback_steps)}."
537
+ )
538
+
539
+ if prompt is not None and prompt_embeds is not None:
540
+ raise ValueError(
541
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
542
+ " only forward one of the two."
543
+ )
544
+ elif prompt is None and prompt_embeds is None:
545
+ raise ValueError(
546
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
547
+ )
548
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
549
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
550
+
551
+ if negative_prompt is not None and negative_prompt_embeds is not None:
552
+ raise ValueError(
553
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
554
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
555
+ )
556
+
557
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
558
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
559
+ raise ValueError(
560
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
561
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
562
+ f" {negative_prompt_embeds.shape}."
563
+ )
564
+
565
+ controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
566
+ controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
567
+ controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
568
+ controlnet_conditioning_image[0], PIL.Image.Image
569
+ )
570
+ controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
571
+ controlnet_conditioning_image[0], torch.Tensor
572
+ )
573
+
574
+ if (
575
+ not controlnet_cond_image_is_pil
576
+ and not controlnet_cond_image_is_tensor
577
+ and not controlnet_cond_image_is_pil_list
578
+ and not controlnet_cond_image_is_tensor_list
579
+ ):
580
+ raise TypeError(
581
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
582
+ )
583
+
584
+ if controlnet_cond_image_is_pil:
585
+ controlnet_cond_image_batch_size = 1
586
+ elif controlnet_cond_image_is_tensor:
587
+ controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
588
+ elif controlnet_cond_image_is_pil_list:
589
+ controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
590
+ elif controlnet_cond_image_is_tensor_list:
591
+ controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
592
+
593
+ if prompt is not None and isinstance(prompt, str):
594
+ prompt_batch_size = 1
595
+ elif prompt is not None and isinstance(prompt, list):
596
+ prompt_batch_size = len(prompt)
597
+ elif prompt_embeds is not None:
598
+ prompt_batch_size = prompt_embeds.shape[0]
599
+
600
+ if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
601
+ raise ValueError(
602
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
603
+ )
604
+
605
+ if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
606
+ raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
607
+
608
+ if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
609
+ raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")
610
+
611
+ if isinstance(image, torch.Tensor):
612
+ if image.ndim != 3 and image.ndim != 4:
613
+ raise ValueError("`image` must have 3 or 4 dimensions")
614
+
615
+ if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
616
+ raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
617
+
618
+ if image.ndim == 3:
619
+ image_batch_size = 1
620
+ image_channels, image_height, image_width = image.shape
621
+ elif image.ndim == 4:
622
+ image_batch_size, image_channels, image_height, image_width = image.shape
623
+
624
+ if mask_image.ndim == 2:
625
+ mask_image_batch_size = 1
626
+ mask_image_channels = 1
627
+ mask_image_height, mask_image_width = mask_image.shape
628
+ elif mask_image.ndim == 3:
629
+ mask_image_channels = 1
630
+ mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
631
+ elif mask_image.ndim == 4:
632
+ mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape
633
+
634
+ if image_channels != 3:
635
+ raise ValueError("`image` must have 3 channels")
636
+
637
+ if mask_image_channels != 1:
638
+ raise ValueError("`mask_image` must have 1 channel")
639
+
640
+ if image_batch_size != mask_image_batch_size:
641
+ raise ValueError("`image` and `mask_image` mush have the same batch sizes")
642
+
643
+ if image_height != mask_image_height or image_width != mask_image_width:
644
+ raise ValueError("`image` and `mask_image` must have the same height and width dimensions")
645
+
646
+ if image.min() < -1 or image.max() > 1:
647
+ raise ValueError("`image` should be in range [-1, 1]")
648
+
649
+ if mask_image.min() < 0 or mask_image.max() > 1:
650
+ raise ValueError("`mask_image` should be in range [0, 1]")
651
+ else:
652
+ mask_image_channels = 1
653
+ image_channels = 3
654
+
655
+ single_image_latent_channels = self.vae.config.latent_channels
656
+
657
+ total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
658
+
659
+ if total_latent_channels != self.unet.config.in_channels:
660
+ raise ValueError(
661
+ f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
662
+ f" non inpainting latent channels: {single_image_latent_channels},"
663
+ f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
664
+ f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
665
+ )
666
+
667
+ if strength < 0 or strength > 1:
668
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
669
+
670
+ def get_timesteps(self, num_inference_steps, strength, device):
671
+ # get the original timestep using init_timestep
672
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
673
+
674
+ t_start = max(num_inference_steps - init_timestep, 0)
675
+ timesteps = self.scheduler.timesteps[t_start:]
676
+
677
+ return timesteps, num_inference_steps - t_start
678
+
679
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
680
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
681
+ raise ValueError(
682
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
683
+ )
684
+
685
+ image = image.to(device=device, dtype=dtype)
686
+
687
+ batch_size = batch_size * num_images_per_prompt
688
+ if isinstance(generator, list) and len(generator) != batch_size:
689
+ raise ValueError(
690
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
691
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
692
+ )
693
+
694
+ if isinstance(generator, list):
695
+ init_latents = [
696
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
697
+ ]
698
+ init_latents = torch.cat(init_latents, dim=0)
699
+ else:
700
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
701
+
702
+ init_latents = self.vae.config.scaling_factor * init_latents
703
+
704
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
705
+ raise ValueError(
706
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
707
+ )
708
+ else:
709
+ init_latents = torch.cat([init_latents], dim=0)
710
+
711
+ shape = init_latents.shape
712
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
713
+
714
+ # get latents
715
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
716
+ latents = init_latents
717
+
718
+ return latents
719
+
720
+ def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
721
+ # resize the mask to latents shape as we concatenate the mask to the latents
722
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
723
+ # and half precision
724
+ mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
725
+ mask_image = mask_image.to(device=device, dtype=dtype)
726
+
727
+ # duplicate mask for each generation per prompt, using mps friendly method
728
+ if mask_image.shape[0] < batch_size:
729
+ if not batch_size % mask_image.shape[0] == 0:
730
+ raise ValueError(
731
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
732
+ f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
733
+ " of masks that you pass is divisible by the total requested batch size."
734
+ )
735
+ mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
736
+
737
+ mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image
738
+
739
+ mask_image_latents = mask_image
740
+
741
+ return mask_image_latents
742
+
743
+ def prepare_masked_image_latents(
744
+ self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
745
+ ):
746
+ masked_image = masked_image.to(device=device, dtype=dtype)
747
+
748
+ # encode the mask image into latents space so we can concatenate it to the latents
749
+ if isinstance(generator, list):
750
+ masked_image_latents = [
751
+ self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
752
+ for i in range(batch_size)
753
+ ]
754
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
755
+ else:
756
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
757
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
758
+
759
+ # duplicate masked_image_latents for each generation per prompt, using mps friendly method
760
+ if masked_image_latents.shape[0] < batch_size:
761
+ if not batch_size % masked_image_latents.shape[0] == 0:
762
+ raise ValueError(
763
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
764
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
765
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
766
+ )
767
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
768
+
769
+ masked_image_latents = (
770
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
771
+ )
772
+
773
+ # aligning device to prevent device errors when concating it with the latent model input
774
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
775
+ return masked_image_latents
776
+
777
+ def _default_height_width(self, height, width, image):
778
+ if isinstance(image, list):
779
+ image = image[0]
780
+
781
+ if height is None:
782
+ if isinstance(image, PIL.Image.Image):
783
+ height = image.height
784
+ elif isinstance(image, torch.Tensor):
785
+ height = image.shape[3]
786
+
787
+ height = (height // 8) * 8 # round down to nearest multiple of 8
788
+
789
+ if width is None:
790
+ if isinstance(image, PIL.Image.Image):
791
+ width = image.width
792
+ elif isinstance(image, torch.Tensor):
793
+ width = image.shape[2]
794
+
795
+ width = (width // 8) * 8 # round down to nearest multiple of 8
796
+
797
+ return height, width
798
+
799
+ @torch.no_grad()
800
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
801
+ def __call__(
802
+ self,
803
+ prompt: Union[str, List[str]] = None,
804
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
805
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
806
+ controlnet_conditioning_image: Union[
807
+ torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
808
+ ] = None,
809
+ strength: float = 0.8,
810
+ height: Optional[int] = None,
811
+ width: Optional[int] = None,
812
+ num_inference_steps: int = 50,
813
+ guidance_scale: float = 7.5,
814
+ negative_prompt: Optional[Union[str, List[str]]] = None,
815
+ num_images_per_prompt: Optional[int] = 1,
816
+ eta: float = 0.0,
817
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
818
+ latents: Optional[torch.FloatTensor] = None,
819
+ prompt_embeds: Optional[torch.FloatTensor] = None,
820
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
821
+ output_type: Optional[str] = "pil",
822
+ return_dict: bool = True,
823
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
824
+ callback_steps: int = 1,
825
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
826
+ controlnet_conditioning_scale: float = 1.0,
827
+ controlnet_conditioning_scale_decay: float = 0.95,
828
+ controlnet_steps: int = 10,
829
+ ):
830
+ r"""
831
+ Function invoked when calling the pipeline for generation.
832
+ Args:
833
+ prompt (`str` or `List[str]`, *optional*):
834
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
835
+ instead.
836
+ image (`torch.Tensor` or `PIL.Image.Image`):
837
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
838
+ be masked out with `mask_image` and repainted according to `prompt`.
839
+ mask_image (`torch.Tensor` or `PIL.Image.Image`):
840
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
841
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
842
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
843
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
844
+ controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
845
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
846
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
847
+ also be accepted as an image. The control image is automatically resized to fit the output image.
848
+ strength (`float`, *optional*):
849
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
850
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
851
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
852
+ be maximum and the denoising process will run for the full number of iterations specified in
853
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
854
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
855
+ The height in pixels of the generated image.
856
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
857
+ The width in pixels of the generated image.
858
+ num_inference_steps (`int`, *optional*, defaults to 50):
859
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
860
+ expense of slower inference.
861
+ guidance_scale (`float`, *optional*, defaults to 7.5):
862
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
863
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
864
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
865
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
866
+ usually at the expense of lower image quality.
867
+ negative_prompt (`str` or `List[str]`, *optional*):
868
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
869
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
870
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
871
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
872
+ The number of images to generate per prompt.
873
+ eta (`float`, *optional*, defaults to 0.0):
874
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
875
+ [`schedulers.DDIMScheduler`], will be ignored for others.
876
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
877
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
878
+ to make generation deterministic.
879
+ latents (`torch.FloatTensor`, *optional*):
880
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
881
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
882
+ tensor will ge generated by sampling using the supplied random `generator`.
883
+ prompt_embeds (`torch.FloatTensor`, *optional*):
884
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
885
+ provided, text embeddings will be generated from `prompt` input argument.
886
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
887
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
888
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
889
+ argument.
890
+ output_type (`str`, *optional*, defaults to `"pil"`):
891
+ The output format of the generate image. Choose between
892
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
893
+ return_dict (`bool`, *optional*, defaults to `True`):
894
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
895
+ plain tuple.
896
+ callback (`Callable`, *optional*):
897
+ A function that will be called every `callback_steps` steps during inference. The function will be
898
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
899
+ callback_steps (`int`, *optional*, defaults to 1):
900
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
901
+ called at every step.
902
+ cross_attention_kwargs (`dict`, *optional*):
903
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
904
+ `self.processor` in
905
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
906
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
907
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
908
+ to the residual in the original unet.
909
+ Examples:
910
+ Returns:
911
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
912
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
913
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
914
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
915
+ (nsfw) content, according to the `safety_checker`.
916
+ """
917
+ # 0. Default height and width to unet
918
+ height, width = self._default_height_width(height, width, controlnet_conditioning_image)
919
+
920
+ # 1. Check inputs. Raise error if not correct
921
+ self.check_inputs(
922
+ prompt,
923
+ image,
924
+ mask_image,
925
+ controlnet_conditioning_image,
926
+ height,
927
+ width,
928
+ callback_steps,
929
+ negative_prompt,
930
+ prompt_embeds,
931
+ negative_prompt_embeds,
932
+ strength,
933
+ )
934
+
935
+ # 2. Define call parameters
936
+ if prompt is not None and isinstance(prompt, str):
937
+ batch_size = 1
938
+ elif prompt is not None and isinstance(prompt, list):
939
+ batch_size = len(prompt)
940
+ else:
941
+ batch_size = prompt_embeds.shape[0]
942
+
943
+ device = self._execution_device
944
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
945
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
946
+ # corresponds to doing no classifier free guidance.
947
+ do_classifier_free_guidance = guidance_scale > 1.0
948
+
949
+ # 3. Encode input prompt
950
+ prompt_embeds = self._encode_prompt(
951
+ prompt,
952
+ device,
953
+ num_images_per_prompt,
954
+ do_classifier_free_guidance,
955
+ negative_prompt,
956
+ prompt_embeds=prompt_embeds,
957
+ negative_prompt_embeds=negative_prompt_embeds,
958
+ )
959
+
960
+ # 4. Prepare mask, image, and controlnet_conditioning_image
961
+ image = prepare_image(image)
962
+
963
+ mask_image = prepare_mask_image(mask_image)
964
+
965
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
966
+ controlnet_conditioning_image,
967
+ width,
968
+ height,
969
+ batch_size * num_images_per_prompt,
970
+ num_images_per_prompt,
971
+ device,
972
+ self.controlnet.dtype,
973
+ )
974
+
975
+ masked_image = image * (mask_image < 0.5)
976
+
977
+ # 5. Prepare timesteps
978
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
979
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
980
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
981
+
982
+ # 6. Prepare latent variables
983
+ latents = self.prepare_latents(
984
+ image,
985
+ latent_timestep,
986
+ batch_size,
987
+ num_images_per_prompt,
988
+ prompt_embeds.dtype,
989
+ device,
990
+ generator,
991
+ )
992
+
993
+ mask_image_latents = self.prepare_mask_latents(
994
+ mask_image,
995
+ batch_size * num_images_per_prompt,
996
+ height,
997
+ width,
998
+ prompt_embeds.dtype,
999
+ device,
1000
+ do_classifier_free_guidance,
1001
+ )
1002
+
1003
+ masked_image_latents = self.prepare_masked_image_latents(
1004
+ masked_image,
1005
+ batch_size * num_images_per_prompt,
1006
+ height,
1007
+ width,
1008
+ prompt_embeds.dtype,
1009
+ device,
1010
+ generator,
1011
+ do_classifier_free_guidance,
1012
+ )
1013
+
1014
+ if do_classifier_free_guidance:
1015
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
1016
+
1017
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1018
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1019
+
1020
+ # 8. Denoising loop
1021
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1022
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1023
+ for i, t in enumerate(timesteps):
1024
+ # expand the latents if we are doing classifier free guidance
1025
+ non_inpainting_latent_model_input = (
1026
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1027
+ )
1028
+
1029
+ non_inpainting_latent_model_input = self.scheduler.scale_model_input(
1030
+ non_inpainting_latent_model_input, t
1031
+ )
1032
+
1033
+ inpainting_latent_model_input = torch.cat(
1034
+ [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
1035
+ )
1036
+
1037
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1038
+ non_inpainting_latent_model_input,
1039
+ t,
1040
+ encoder_hidden_states=prompt_embeds,
1041
+ controlnet_cond=controlnet_conditioning_image,
1042
+ return_dict=False,
1043
+ )
1044
+ if i <= controlnet_steps:
1045
+ conditioning_scale = (controlnet_conditioning_scale * controlnet_conditioning_scale_decay ** i)
1046
+ else:
1047
+ conditioning_scale = 0.0
1048
+
1049
+ down_block_res_samples = [
1050
+ down_block_res_sample * conditioning_scale
1051
+ for down_block_res_sample in down_block_res_samples
1052
+ ]
1053
+ mid_block_res_sample *= conditioning_scale
1054
+
1055
+ # predict the noise residual
1056
+ noise_pred = self.unet(
1057
+ inpainting_latent_model_input,
1058
+ t,
1059
+ encoder_hidden_states=prompt_embeds,
1060
+ cross_attention_kwargs=cross_attention_kwargs,
1061
+ down_block_additional_residuals=down_block_res_samples,
1062
+ mid_block_additional_residual=mid_block_res_sample,
1063
+ ).sample
1064
+
1065
+ # perform guidance
1066
+ if do_classifier_free_guidance:
1067
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1068
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1069
+
1070
+ # compute the previous noisy sample x_t -> x_t-1
1071
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1072
+
1073
+ # call the callback, if provided
1074
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1075
+ progress_bar.update()
1076
+ if callback is not None and i % callback_steps == 0:
1077
+ callback(i, t, latents)
1078
+
1079
+ # If we do sequential model offloading, let's offload unet and controlnet
1080
+ # manually for max memory savings
1081
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1082
+ self.unet.to("cpu")
1083
+ self.controlnet.to("cpu")
1084
+ torch.cuda.empty_cache()
1085
+
1086
+ if output_type == "latent":
1087
+ image = latents
1088
+ has_nsfw_concept = None
1089
+ elif output_type == "pil":
1090
+ # 8. Post-processing
1091
+ image = self.decode_latents(latents)
1092
+
1093
+ # 9. Run safety checker
1094
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1095
+
1096
+ # 10. Convert to PIL
1097
+ image = self.numpy_to_pil(image)
1098
+ else:
1099
+ # 8. Post-processing
1100
+ image = self.decode_latents(latents)
1101
+
1102
+ # 9. Run safety checker
1103
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1104
+
1105
+ # Offload last model to CPU
1106
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1107
+ self.final_offload_hook.offload()
1108
+
1109
+ if not return_dict:
1110
+ return (image, has_nsfw_concept)
1111
+
1112
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)