DmitrMakeev commited on
Commit
b1342b4
·
verified ·
1 Parent(s): 15d6587

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +374 -0
model.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ from abc import ABC, abstractmethod
4
+
5
+ import numpy as np
6
+ import PIL.Image
7
+ import torch
8
+ from controlnet_aux import (
9
+ CannyDetector,
10
+ LineartDetector,
11
+ MidasDetector,
12
+ OpenposeDetector,
13
+ PidiNetDetector,
14
+ ZoeDetector,
15
+ )
16
+ from diffusers import (
17
+ AutoencoderKL,
18
+ EulerAncestralDiscreteScheduler,
19
+ StableDiffusionXLAdapterPipeline,
20
+ T2IAdapter,
21
+ )
22
+
23
+ SD_XL_BASE_RATIOS = {
24
+ "0.5": (704, 1408),
25
+ "0.52": (704, 1344),
26
+ "0.57": (768, 1344),
27
+ "0.6": (768, 1280),
28
+ "0.68": (832, 1216),
29
+ "0.72": (832, 1152),
30
+ "0.78": (896, 1152),
31
+ "0.82": (896, 1088),
32
+ "0.88": (960, 1088),
33
+ "0.94": (960, 1024),
34
+ "1.0": (1024, 1024),
35
+ "1.07": (1024, 960),
36
+ "1.13": (1088, 960),
37
+ "1.21": (1088, 896),
38
+ "1.29": (1152, 896),
39
+ "1.38": (1152, 832),
40
+ "1.46": (1216, 832),
41
+ "1.67": (1280, 768),
42
+ "1.75": (1344, 768),
43
+ "1.91": (1344, 704),
44
+ "2.0": (1408, 704),
45
+ "2.09": (1472, 704),
46
+ "2.4": (1536, 640),
47
+ "2.5": (1600, 640),
48
+ "2.89": (1664, 576),
49
+ "3.0": (1728, 576),
50
+ }
51
+
52
+
53
+ def find_closest_aspect_ratio(target_width: int, target_height: int) -> str:
54
+ target_ratio = target_width / target_height
55
+ closest_ratio = ""
56
+ min_difference = float("inf")
57
+
58
+ for ratio_str, (width, height) in SD_XL_BASE_RATIOS.items():
59
+ ratio = width / height
60
+ difference = abs(target_ratio - ratio)
61
+
62
+ if difference < min_difference:
63
+ min_difference = difference
64
+ closest_ratio = ratio_str
65
+
66
+ return closest_ratio
67
+
68
+
69
+ def resize_to_closest_aspect_ratio(image: PIL.Image.Image) -> PIL.Image.Image:
70
+ target_width, target_height = image.size
71
+ closest_ratio = find_closest_aspect_ratio(target_width, target_height)
72
+
73
+ # Get the dimensions from the closest aspect ratio in the dictionary
74
+ new_width, new_height = SD_XL_BASE_RATIOS[closest_ratio]
75
+
76
+ # Resize the image to the new dimensions while preserving the aspect ratio
77
+ resized_image = image.resize((new_width, new_height), PIL.Image.LANCZOS)
78
+
79
+ return resized_image
80
+
81
+
82
+ ADAPTER_REPO_IDS = {
83
+ "canny": "TencentARC/t2i-adapter-canny-sdxl-1.0",
84
+ "sketch": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
85
+ "lineart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
86
+ "depth-midas": "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
87
+ "depth-zoe": "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
88
+ "openpose": "TencentARC/t2i-adapter-openpose-sdxl-1.0",
89
+ # "recolor": "TencentARC/t2i-adapter-recolor-sdxl-1.0",
90
+ }
91
+ ADAPTER_NAMES = list(ADAPTER_REPO_IDS.keys())
92
+
93
+
94
+ class Preprocessor(ABC):
95
+ @abstractmethod
96
+ def to(self, device: torch.device | str) -> "Preprocessor":
97
+ pass
98
+
99
+ @abstractmethod
100
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
101
+ pass
102
+
103
+
104
+ class CannyPreprocessor(Preprocessor):
105
+ def __init__(self):
106
+ self.model = CannyDetector()
107
+
108
+ def to(self, device: torch.device | str) -> Preprocessor:
109
+ return self
110
+
111
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
112
+ return self.model(image, detect_resolution=384, image_resolution=1024)
113
+
114
+
115
+ class LineartPreprocessor(Preprocessor):
116
+ def __init__(self):
117
+ self.model = LineartDetector.from_pretrained("lllyasviel/Annotators")
118
+
119
+ def to(self, device: torch.device | str) -> Preprocessor:
120
+ self.model.to(device)
121
+ return self
122
+
123
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
124
+ return self.model(image, detect_resolution=384, image_resolution=1024)
125
+
126
+
127
+ class MidasPreprocessor(Preprocessor):
128
+ def __init__(self):
129
+ self.model = MidasDetector.from_pretrained(
130
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
131
+ )
132
+
133
+ def to(self, device: torch.device | str) -> Preprocessor:
134
+ self.model.to(device)
135
+ return self
136
+
137
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
138
+ return self.model(image, detect_resolution=512, image_resolution=1024)
139
+
140
+
141
+ class OpenposePreprocessor(Preprocessor):
142
+ def __init__(self):
143
+ self.model = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
144
+
145
+ def to(self, device: torch.device | str) -> Preprocessor:
146
+ self.model.to(device)
147
+ return self
148
+
149
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
150
+ out = self.model(image, detect_resolution=512, image_resolution=1024)
151
+ out = np.array(out)[:, :, ::-1]
152
+ out = PIL.Image.fromarray(np.uint8(out))
153
+ return out
154
+
155
+
156
+ class PidiNetPreprocessor(Preprocessor):
157
+ def __init__(self):
158
+ self.model = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
159
+
160
+ def to(self, device: torch.device | str) -> Preprocessor:
161
+ self.model.to(device)
162
+ return self
163
+
164
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
165
+ return self.model(image, detect_resolution=512, image_resolution=1024, apply_filter=True)
166
+
167
+
168
+ class RecolorPreprocessor(Preprocessor):
169
+ def to(self, device: torch.device | str) -> Preprocessor:
170
+ return self
171
+
172
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
173
+ return image.convert("L").convert("RGB")
174
+
175
+
176
+ class ZoePreprocessor(Preprocessor):
177
+ def __init__(self):
178
+ self.model = ZoeDetector.from_pretrained(
179
+ "valhalla/t2iadapter-aux-models", filename="zoed_nk.pth", model_type="zoedepth_nk"
180
+ )
181
+
182
+ def to(self, device: torch.device | str) -> Preprocessor:
183
+ self.model.to(device)
184
+ return self
185
+
186
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
187
+ return self.model(image, gamma_corrected=True, image_resolution=1024)
188
+
189
+
190
+ PRELOAD_PREPROCESSORS_IN_GPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_GPU_MEMORY", "0") == "1"
191
+ PRELOAD_PREPROCESSORS_IN_CPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_CPU_MEMORY", "0") == "1"
192
+ if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
193
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
194
+ preprocessors_gpu: dict[str, Preprocessor] = {
195
+ "canny": CannyPreprocessor().to(device),
196
+ "sketch": PidiNetPreprocessor().to(device),
197
+ "lineart": LineartPreprocessor().to(device),
198
+ "depth-midas": MidasPreprocessor().to(device),
199
+ "depth-zoe": ZoePreprocessor().to(device),
200
+ "openpose": OpenposePreprocessor().to(device),
201
+ "recolor": RecolorPreprocessor().to(device),
202
+ }
203
+
204
+ def get_preprocessor(adapter_name: str) -> Preprocessor:
205
+ return preprocessors_gpu[adapter_name]
206
+
207
+ elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
208
+ preprocessors_cpu: dict[str, Preprocessor] = {
209
+ "canny": CannyPreprocessor(),
210
+ "sketch": PidiNetPreprocessor(),
211
+ "lineart": LineartPreprocessor(),
212
+ "depth-midas": MidasPreprocessor(),
213
+ "depth-zoe": ZoePreprocessor(),
214
+ "openpose": OpenposePreprocessor(),
215
+ "recolor": RecolorPreprocessor(),
216
+ }
217
+
218
+ def get_preprocessor(adapter_name: str) -> Preprocessor:
219
+ return preprocessors_cpu[adapter_name]
220
+
221
+ else:
222
+
223
+ def get_preprocessor(adapter_name: str) -> Preprocessor:
224
+ if adapter_name == "canny":
225
+ return CannyPreprocessor()
226
+ elif adapter_name == "sketch":
227
+ return PidiNetPreprocessor()
228
+ elif adapter_name == "lineart":
229
+ return LineartPreprocessor()
230
+ elif adapter_name == "depth-midas":
231
+ return MidasPreprocessor()
232
+ elif adapter_name == "depth-zoe":
233
+ return ZoePreprocessor()
234
+ elif adapter_name == "openpose":
235
+ return OpenposePreprocessor()
236
+ elif adapter_name == "recolor":
237
+ return RecolorPreprocessor()
238
+ else:
239
+ raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
240
+
241
+ def download_all_preprocessors():
242
+ for adapter_name in ADAPTER_NAMES:
243
+ get_preprocessor(adapter_name)
244
+ gc.collect()
245
+
246
+ download_all_preprocessors()
247
+
248
+
249
+ def download_all_adapters():
250
+ for adapter_name in ADAPTER_NAMES:
251
+ T2IAdapter.from_pretrained(
252
+ ADAPTER_REPO_IDS[adapter_name],
253
+ torch_dtype=torch.float16,
254
+ varient="fp16",
255
+ )
256
+ gc.collect()
257
+
258
+
259
+ class Model:
260
+ MAX_NUM_INFERENCE_STEPS = 50
261
+
262
+ def __init__(self, adapter_name: str):
263
+ if adapter_name not in ADAPTER_NAMES:
264
+ raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
265
+
266
+ self.preprocessor_name = adapter_name
267
+ self.adapter_name = adapter_name
268
+
269
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
270
+ if torch.cuda.is_available():
271
+ self.preprocessor = get_preprocessor(adapter_name).to(self.device)
272
+
273
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
274
+ adapter = T2IAdapter.from_pretrained(
275
+ ADAPTER_REPO_IDS[adapter_name],
276
+ torch_dtype=torch.float16,
277
+ varient="fp16",
278
+ ).to(self.device)
279
+ self.pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
280
+ model_id,
281
+ vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
282
+ adapter=adapter,
283
+ scheduler=EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler"),
284
+ torch_dtype=torch.float16,
285
+ variant="fp16",
286
+ ).to(self.device)
287
+ self.pipe.enable_xformers_memory_efficient_attention()
288
+ self.pipe.load_lora_weights(
289
+ "stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors"
290
+ )
291
+ self.pipe.fuse_lora(lora_scale=0.4)
292
+ else:
293
+ self.preprocessor = None # type: ignore
294
+ self.pipe = None
295
+
296
+ def change_preprocessor(self, adapter_name: str) -> None:
297
+ if adapter_name not in ADAPTER_NAMES:
298
+ raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
299
+ if adapter_name == self.preprocessor_name:
300
+ return
301
+
302
+ if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
303
+ pass
304
+ elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
305
+ self.preprocessor.to("cpu")
306
+ else:
307
+ del self.preprocessor
308
+ self.preprocessor = get_preprocessor(adapter_name).to(self.device)
309
+ self.preprocessor_name = adapter_name
310
+ gc.collect()
311
+ torch.cuda.empty_cache()
312
+
313
+ def change_adapter(self, adapter_name: str) -> None:
314
+ if adapter_name not in ADAPTER_NAMES:
315
+ raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
316
+ if adapter_name == self.adapter_name:
317
+ return
318
+ self.pipe.adapter = T2IAdapter.from_pretrained(
319
+ ADAPTER_REPO_IDS[adapter_name],
320
+ torch_dtype=torch.float16,
321
+ varient="fp16",
322
+ ).to(self.device)
323
+ self.adapter_name = adapter_name
324
+ gc.collect()
325
+ torch.cuda.empty_cache()
326
+
327
+ def resize_image(self, image: PIL.Image.Image) -> PIL.Image.Image:
328
+ w, h = image.size
329
+ scale = 1024 / max(w, h)
330
+ new_w = int(w * scale)
331
+ new_h = int(h * scale)
332
+ return image.resize((new_w, new_h), PIL.Image.LANCZOS)
333
+
334
+ def run(
335
+ self,
336
+ image: PIL.Image.Image,
337
+ prompt: str,
338
+ negative_prompt: str,
339
+ adapter_name: str,
340
+ num_inference_steps: int = 30,
341
+ guidance_scale: float = 5.0,
342
+ adapter_conditioning_scale: float = 1.0,
343
+ adapter_conditioning_factor: float = 1.0,
344
+ seed: int = 0,
345
+ apply_preprocess: bool = True,
346
+ ) -> list[PIL.Image.Image]:
347
+ if not torch.cuda.is_available():
348
+ raise RuntimeError("This demo does not work on CPU.")
349
+ if num_inference_steps > self.MAX_NUM_INFERENCE_STEPS:
350
+ raise ValueError(f"Number of steps must be less than {self.MAX_NUM_INFERENCE_STEPS}")
351
+
352
+ # Resize image to avoid OOM
353
+ image = self.resize_image(image)
354
+
355
+ self.change_preprocessor(adapter_name)
356
+ self.change_adapter(adapter_name)
357
+
358
+ if apply_preprocess:
359
+ image = self.preprocessor(image)
360
+
361
+ image = resize_to_closest_aspect_ratio(image)
362
+
363
+ generator = torch.Generator(device=self.device).manual_seed(seed)
364
+ out = self.pipe(
365
+ prompt=prompt,
366
+ negative_prompt=negative_prompt,
367
+ image=image,
368
+ num_inference_steps=num_inference_steps,
369
+ adapter_conditioning_scale=adapter_conditioning_scale,
370
+ adapter_conditioning_factor=adapter_conditioning_factor,
371
+ generator=generator,
372
+ guidance_scale=guidance_scale,
373
+ ).images[0]
374
+ return [image, out]