abhishek HF staff commited on
Commit
58f667f
·
0 Parent(s):

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. .gitignore +3 -0
  3. README.md +13 -0
  4. __pycache__/config.cpython-310.pyc +0 -0
  5. __pycache__/model.cpython-310.pyc +0 -0
  6. __pycache__/utils.cpython-310.pyc +0 -0
  7. _app.py +1360 -0
  8. annotator/__pycache__/util.cpython-310.pyc +0 -0
  9. annotator/__pycache__/util.cpython-38.pyc +0 -0
  10. annotator/blur/__init__.py +7 -0
  11. annotator/blur/__pycache__/__init__.cpython-310.pyc +0 -0
  12. annotator/blur/__pycache__/__init__.cpython-38.pyc +0 -0
  13. annotator/canny/__init__.py +16 -0
  14. annotator/canny/__pycache__/__init__.cpython-310.pyc +0 -0
  15. annotator/canny/__pycache__/__init__.cpython-38.pyc +0 -0
  16. annotator/ckpts/ckpts.txt +1 -0
  17. annotator/grayscale/__init__.py +5 -0
  18. annotator/grayscale/__pycache__/__init__.cpython-310.pyc +0 -0
  19. annotator/grayscale/__pycache__/__init__.cpython-38.pyc +0 -0
  20. annotator/hed/__init__.py +107 -0
  21. annotator/hed/__pycache__/__init__.cpython-310.pyc +0 -0
  22. annotator/hed/__pycache__/__init__.cpython-38.pyc +0 -0
  23. annotator/inpainting/__init__.py +16 -0
  24. annotator/inpainting/__pycache__/__init__.cpython-310.pyc +0 -0
  25. annotator/inpainting/__pycache__/__init__.cpython-38.pyc +0 -0
  26. annotator/midas/LICENSE +21 -0
  27. annotator/midas/__init__.py +52 -0
  28. annotator/midas/__pycache__/__init__.cpython-310.pyc +0 -0
  29. annotator/midas/__pycache__/__init__.cpython-38.pyc +0 -0
  30. annotator/midas/__pycache__/api.cpython-310.pyc +0 -0
  31. annotator/midas/__pycache__/api.cpython-38.pyc +0 -0
  32. annotator/midas/api.py +183 -0
  33. annotator/midas/midas/__init__.py +0 -0
  34. annotator/midas/midas/__pycache__/__init__.cpython-310.pyc +0 -0
  35. annotator/midas/midas/__pycache__/__init__.cpython-38.pyc +0 -0
  36. annotator/midas/midas/__pycache__/base_model.cpython-310.pyc +0 -0
  37. annotator/midas/midas/__pycache__/base_model.cpython-38.pyc +0 -0
  38. annotator/midas/midas/__pycache__/blocks.cpython-310.pyc +0 -0
  39. annotator/midas/midas/__pycache__/blocks.cpython-38.pyc +0 -0
  40. annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc +0 -0
  41. annotator/midas/midas/__pycache__/dpt_depth.cpython-38.pyc +0 -0
  42. annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc +0 -0
  43. annotator/midas/midas/__pycache__/midas_net.cpython-38.pyc +0 -0
  44. annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc +0 -0
  45. annotator/midas/midas/__pycache__/midas_net_custom.cpython-38.pyc +0 -0
  46. annotator/midas/midas/__pycache__/transforms.cpython-310.pyc +0 -0
  47. annotator/midas/midas/__pycache__/transforms.cpython-38.pyc +0 -0
  48. annotator/midas/midas/__pycache__/vit.cpython-310.pyc +0 -0
  49. annotator/midas/midas/__pycache__/vit.cpython-38.pyc +0 -0
  50. annotator/midas/midas/base_model.py +26 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.pt
2
+ *.pth
3
+ *.st
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: UniControl Demo
3
+ emoji: 📚
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/config.cpython-310.pyc ADDED
Binary file (532 Bytes). View file
 
__pycache__/model.cpython-310.pyc ADDED
Binary file (15.7 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.46 kB). View file
 
_app.py ADDED
@@ -0,0 +1,1360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2023 Salesforce, Inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: Apache License 2.0
5
+ * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6
+ * By Can Qin
7
+ * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
8
+ * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
9
+ '''
10
+
11
+ import config
12
+
13
+ import cv2
14
+ import einops
15
+ import gradio as gr
16
+ import numpy as np
17
+ import torch
18
+ import random
19
+ import os
20
+
21
+ from pytorch_lightning import seed_everything
22
+ from annotator.util import resize_image, HWC3
23
+ from annotator.uniformer_base import UniformerDetector
24
+ from annotator.hed import HEDdetector
25
+ from annotator.canny import CannyDetector
26
+ from annotator.midas import MidasDetector
27
+ from annotator.outpainting import Outpainter
28
+ from annotator.openpose import OpenposeDetector
29
+ from annotator.inpainting import Inpainter
30
+ from annotator.grayscale import GrayscaleConverter
31
+ from annotator.blur import Blurrer
32
+ import cvlib as cv
33
+
34
+ from utils import create_model, load_state_dict
35
+ from lib.ddim_hacked import DDIMSampler
36
+
37
+ from safetensors.torch import load_file as stload
38
+ from collections import OrderedDict
39
+
40
+ apply_uniformer = UniformerDetector()
41
+ apply_midas = MidasDetector()
42
+ apply_canny = CannyDetector()
43
+ apply_hed = HEDdetector()
44
+ model_outpainting = Outpainter()
45
+ apply_openpose = OpenposeDetector()
46
+ model_grayscale = GrayscaleConverter()
47
+ model_blur = Blurrer()
48
+ model_inpainting = Inpainter()
49
+
50
+
51
+ def midas(img, res):
52
+ img = resize_image(HWC3(img), res)
53
+ results = apply_midas(img)
54
+ return results
55
+
56
+
57
+ def outpainting(img, res, height_top_extended, height_down_extended, width_left_extended, width_right_extended):
58
+ img = resize_image(HWC3(img), res)
59
+ result = model_outpainting(img, height_top_extended, height_down_extended, width_left_extended, width_right_extended)
60
+ return result
61
+
62
+
63
+ def grayscale(img, res):
64
+ img = resize_image(HWC3(img), res)
65
+ result = model_grayscale(img)
66
+ return result
67
+
68
+
69
+ def blur(img, res, ksize):
70
+ img = resize_image(HWC3(img), res)
71
+ result = model_blur(img, ksize)
72
+ return result
73
+
74
+
75
+ def inpainting(img, res, height_top_mask, height_down_mask, width_left_mask, width_right_mask):
76
+ img = resize_image(HWC3(img), res)
77
+ result = model_inpainting(img, height_top_mask, height_down_mask, width_left_mask, width_right_mask)
78
+ return result
79
+
80
+ model = create_model('./models/cldm_v15_unicontrol.yaml').cpu()
81
+ # model_url = 'https://huggingface.co/Robert001/UniControl-Model/resolve/main/unicontrol_v1.1.ckpt'
82
+ model_url = 'https://huggingface.co/Robert001/UniControl-Model/resolve/main/unicontrol_v1.1.st'
83
+
84
+ ckpts_path='./'
85
+ # model_path = os.path.join(ckpts_path, "unicontrol_v1.1.ckpt")
86
+ model_path = os.path.join(ckpts_path, "unicontrol_v1.1.st")
87
+
88
+ if not os.path.exists(model_path):
89
+ from basicsr.utils.download_util import load_file_from_url
90
+ load_file_from_url(model_url, model_dir=ckpts_path)
91
+
92
+ model_dict = OrderedDict(stload(model_path, device='cpu'))
93
+ model.load_state_dict(model_dict, strict=False)
94
+ # model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
95
+ model = model.cuda()
96
+ ddim_sampler = DDIMSampler(model)
97
+
98
+ task_to_name = {'hed': 'control_hed', 'canny': 'control_canny', 'seg': 'control_seg', 'segbase': 'control_seg',
99
+ 'depth': 'control_depth', 'normal': 'control_normal', 'openpose': 'control_openpose',
100
+ 'bbox': 'control_bbox', 'grayscale': 'control_grayscale', 'outpainting': 'control_outpainting',
101
+ 'hedsketch': 'control_hedsketch', 'inpainting': 'control_inpainting', 'blur': 'control_blur',
102
+ 'grayscale': 'control_grayscale'}
103
+
104
+ name_to_instruction = {"control_hed": "hed edge to image", "control_canny": "canny edge to image",
105
+ "control_seg": "segmentation map to image", "control_depth": "depth map to image",
106
+ "control_normal": "normal surface map to image", "control_img": "image editing",
107
+ "control_openpose": "human pose skeleton to image", "control_hedsketch": "sketch to image",
108
+ "control_bbox": "bounding box to image", "control_outpainting": "image outpainting",
109
+ "control_grayscale": "gray image to color image", "control_blur": "deblur image to clean image",
110
+ "control_inpainting": "image inpainting"}
111
+
112
+
113
+ def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
114
+ strength, scale, seed, eta, low_threshold, high_threshold, condition_mode):
115
+ with torch.no_grad():
116
+ img = resize_image(HWC3(input_image), image_resolution)
117
+ H, W, C = img.shape
118
+ if condition_mode == True:
119
+ detected_map = apply_canny(img, low_threshold, high_threshold)
120
+ detected_map = HWC3(detected_map)
121
+ else:
122
+ detected_map = 255 - img
123
+
124
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
125
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
126
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
127
+
128
+ if seed == -1:
129
+ seed = random.randint(0, 65535)
130
+ seed_everything(seed)
131
+
132
+ if config.save_memory:
133
+ model.low_vram_shift(is_diffusing=False)
134
+ task = 'canny'
135
+ task_dic = {}
136
+ task_dic['name'] = task_to_name[task]
137
+ task_instruction = name_to_instruction[task_dic['name']]
138
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
139
+
140
+ cond = {"c_concat": [control],
141
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
142
+ "task": task_dic}
143
+
144
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
145
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
146
+ shape = (4, H // 8, W // 8)
147
+
148
+ if config.save_memory:
149
+ model.low_vram_shift(is_diffusing=True)
150
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
151
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
152
+ shape, cond, verbose=False, eta=eta,
153
+ unconditional_guidance_scale=scale,
154
+ unconditional_conditioning=un_cond)
155
+
156
+ if config.save_memory:
157
+ model.low_vram_shift(is_diffusing=False)
158
+
159
+ x_samples = model.decode_first_stage(samples)
160
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
161
+ 255).astype(
162
+ np.uint8)
163
+
164
+ results = [x_samples[i] for i in range(num_samples)]
165
+ return [255 - detected_map] + results
166
+
167
+
168
+ def process_hed(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps,
169
+ guess_mode, strength, scale, seed, eta, condition_mode):
170
+ with torch.no_grad():
171
+ input_image = HWC3(input_image)
172
+ img = resize_image(input_image, image_resolution)
173
+ H, W, C = img.shape
174
+ if condition_mode == True:
175
+ detected_map = apply_hed(resize_image(input_image, detect_resolution))
176
+ detected_map = HWC3(detected_map)
177
+ else:
178
+ detected_map = img
179
+
180
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
181
+
182
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
183
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
184
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
185
+
186
+ if seed == -1:
187
+ seed = random.randint(0, 65535)
188
+ seed_everything(seed)
189
+
190
+ if config.save_memory:
191
+ model.low_vram_shift(is_diffusing=False)
192
+
193
+ task = 'hed'
194
+ task_dic = {}
195
+ task_dic['name'] = task_to_name[task]
196
+ task_instruction = name_to_instruction[task_dic['name']]
197
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
198
+
199
+ cond = {"c_concat": [control],
200
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
201
+ "task": task_dic}
202
+
203
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
204
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
205
+ shape = (4, H // 8, W // 8)
206
+
207
+ if config.save_memory:
208
+ model.low_vram_shift(is_diffusing=True)
209
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
210
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
211
+ shape, cond, verbose=False, eta=eta,
212
+ unconditional_guidance_scale=scale,
213
+ unconditional_conditioning=un_cond)
214
+
215
+ if config.save_memory:
216
+ model.low_vram_shift(is_diffusing=False)
217
+
218
+ x_samples = model.decode_first_stage(samples)
219
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
220
+ 255).astype(
221
+ np.uint8)
222
+
223
+ results = [x_samples[i] for i in range(num_samples)]
224
+ return [detected_map] + results
225
+
226
+
227
+ def process_depth(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps,
228
+ guess_mode, strength, scale, seed, eta, condition_mode):
229
+ with torch.no_grad():
230
+ input_image = HWC3(input_image)
231
+ img = resize_image(input_image, image_resolution)
232
+ H, W, C = img.shape
233
+ if condition_mode == True:
234
+ detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
235
+ detected_map = HWC3(detected_map)
236
+ else:
237
+ detected_map = img
238
+
239
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
240
+
241
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
242
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
243
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
244
+
245
+ if seed == -1:
246
+ seed = random.randint(0, 65535)
247
+ seed_everything(seed)
248
+
249
+ if config.save_memory:
250
+ model.low_vram_shift(is_diffusing=False)
251
+ task = 'depth'
252
+ task_dic = {}
253
+ task_dic['name'] = task_to_name[task]
254
+ task_instruction = name_to_instruction[task_dic['name']]
255
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
256
+ cond = {"c_concat": [control],
257
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
258
+ "task": task_dic}
259
+
260
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
261
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
262
+ shape = (4, H // 8, W // 8)
263
+
264
+ if config.save_memory:
265
+ model.low_vram_shift(is_diffusing=True)
266
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
267
+ [strength] * 13)
268
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
269
+ shape, cond, verbose=False, eta=eta,
270
+ unconditional_guidance_scale=scale,
271
+ unconditional_conditioning=un_cond)
272
+
273
+ if config.save_memory:
274
+ model.low_vram_shift(is_diffusing=False)
275
+
276
+ x_samples = model.decode_first_stage(samples)
277
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
278
+ 255).astype(
279
+ np.uint8)
280
+
281
+ results = [x_samples[i] for i in range(num_samples)]
282
+ return [detected_map] + results
283
+
284
+
285
+ def process_normal(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution,
286
+ ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode):
287
+ with torch.no_grad():
288
+
289
+ input_image = HWC3(input_image)
290
+ img = resize_image(input_image, image_resolution)
291
+ H, W, C = img.shape
292
+ if condition_mode == True:
293
+ _, detected_map = apply_midas(resize_image(input_image, detect_resolution))
294
+ detected_map = HWC3(detected_map)
295
+ else:
296
+ detected_map = img
297
+
298
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
299
+
300
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
301
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
302
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
303
+
304
+ if seed == -1:
305
+ seed = random.randint(0, 65535)
306
+ seed_everything(seed)
307
+
308
+ if config.save_memory:
309
+ model.low_vram_shift(is_diffusing=False)
310
+ task = 'normal'
311
+ task_dic = {}
312
+ task_dic['name'] = task_to_name[task]
313
+ task_instruction = name_to_instruction[task_dic['name']]
314
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
315
+ cond = {"c_concat": [control],
316
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
317
+ "task": task_dic}
318
+
319
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
320
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
321
+ shape = (4, H // 8, W // 8)
322
+
323
+ if config.save_memory:
324
+ model.low_vram_shift(is_diffusing=True)
325
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
326
+ [strength] * 13)
327
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
328
+ shape, cond, verbose=False, eta=eta,
329
+ unconditional_guidance_scale=scale,
330
+ unconditional_conditioning=un_cond)
331
+
332
+ if config.save_memory:
333
+ model.low_vram_shift(is_diffusing=False)
334
+
335
+ x_samples = model.decode_first_stage(samples)
336
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
337
+ 255).astype(
338
+ np.uint8)
339
+
340
+ results = [x_samples[i] for i in range(num_samples)]
341
+ return [detected_map] + results
342
+
343
+
344
+ def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps,
345
+ guess_mode, strength, scale, seed, eta, condition_mode):
346
+ with torch.no_grad():
347
+ input_image = HWC3(input_image)
348
+ img = resize_image(input_image, image_resolution)
349
+ H, W, C = img.shape
350
+ if condition_mode == True:
351
+ detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
352
+ detected_map = HWC3(detected_map)
353
+ else:
354
+ detected_map = img
355
+
356
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
357
+
358
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
359
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
360
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
361
+
362
+ if seed == -1:
363
+ seed = random.randint(0, 65535)
364
+ seed_everything(seed)
365
+
366
+ if config.save_memory:
367
+ model.low_vram_shift(is_diffusing=False)
368
+ task = 'openpose'
369
+ task_dic = {}
370
+ task_dic['name'] = task_to_name[task]
371
+ task_instruction = name_to_instruction[task_dic['name']]
372
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
373
+ cond = {"c_concat": [control],
374
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
375
+ "task": task_dic}
376
+
377
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
378
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
379
+ shape = (4, H // 8, W // 8)
380
+
381
+ if config.save_memory:
382
+ model.low_vram_shift(is_diffusing=True)
383
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
384
+ [strength] * 13)
385
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
386
+ shape, cond, verbose=False, eta=eta,
387
+ unconditional_guidance_scale=scale,
388
+ unconditional_conditioning=un_cond)
389
+
390
+ if config.save_memory:
391
+ model.low_vram_shift(is_diffusing=False)
392
+
393
+ x_samples = model.decode_first_stage(samples)
394
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
395
+ 255).astype(
396
+ np.uint8)
397
+
398
+ results = [x_samples[i] for i in range(num_samples)]
399
+ return [detected_map] + results
400
+
401
+
402
+ def process_seg(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps,
403
+ guess_mode, strength, scale, seed, eta, condition_mode):
404
+ with torch.no_grad():
405
+ input_image = HWC3(input_image)
406
+ img = resize_image(input_image, image_resolution)
407
+ H, W, C = img.shape
408
+
409
+ if condition_mode == True:
410
+ detected_map = apply_uniformer(resize_image(input_image, detect_resolution))
411
+ else:
412
+ detected_map = img
413
+
414
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
415
+
416
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
417
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
418
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
419
+
420
+ if seed == -1:
421
+ seed = random.randint(0, 65535)
422
+ seed_everything(seed)
423
+
424
+ if config.save_memory:
425
+ model.low_vram_shift(is_diffusing=False)
426
+ task = 'seg'
427
+ task_dic = {}
428
+ task_dic['name'] = task_to_name[task]
429
+ task_instruction = name_to_instruction[task_dic['name']]
430
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
431
+
432
+ cond = {"c_concat": [control],
433
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
434
+ "task": task_dic}
435
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
436
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
437
+ shape = (4, H // 8, W // 8)
438
+
439
+ if config.save_memory:
440
+ model.low_vram_shift(is_diffusing=True)
441
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
442
+ [strength] * 13)
443
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
444
+ shape, cond, verbose=False, eta=eta,
445
+ unconditional_guidance_scale=scale,
446
+ unconditional_conditioning=un_cond)
447
+
448
+ if config.save_memory:
449
+ model.low_vram_shift(is_diffusing=False)
450
+
451
+ x_samples = model.decode_first_stage(samples)
452
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
453
+ 255).astype(
454
+ np.uint8)
455
+
456
+ results = [x_samples[i] for i in range(num_samples)]
457
+ return [detected_map] + results
458
+
459
+
460
+ color_dict = {
461
+ 'background': (0, 0, 100),
462
+ 'person': (255, 0, 0),
463
+ 'bicycle': (0, 255, 0),
464
+ 'car': (0, 0, 255),
465
+ 'motorcycle': (255, 255, 0),
466
+ 'airplane': (255, 0, 255),
467
+ 'bus': (0, 255, 255),
468
+ 'train': (128, 128, 0),
469
+ 'truck': (128, 0, 128),
470
+ 'boat': (0, 128, 128),
471
+ 'traffic light': (128, 128, 128),
472
+ 'fire hydrant': (64, 0, 0),
473
+ 'stop sign': (0, 64, 0),
474
+ 'parking meter': (0, 0, 64),
475
+ 'bench': (64, 64, 0),
476
+ 'bird': (64, 0, 64),
477
+ 'cat': (0, 64, 64),
478
+ 'dog': (192, 192, 192),
479
+ 'horse': (32, 32, 32),
480
+ 'sheep': (96, 96, 96),
481
+ 'cow': (160, 160, 160),
482
+ 'elephant': (224, 224, 224),
483
+ 'bear': (32, 0, 0),
484
+ 'zebra': (0, 32, 0),
485
+ 'giraffe': (0, 0, 32),
486
+ 'backpack': (32, 32, 0),
487
+ 'umbrella': (32, 0, 32),
488
+ 'handbag': (0, 32, 32),
489
+ 'tie': (96, 0, 0),
490
+ 'suitcase': (0, 96, 0),
491
+ 'frisbee': (0, 0, 96),
492
+ 'skis': (96, 96, 0),
493
+ 'snowboard': (96, 0, 96),
494
+ 'sports ball': (0, 96, 96),
495
+ 'kite': (160, 0, 0),
496
+ 'baseball bat': (0, 160, 0),
497
+ 'baseball glove': (0, 0, 160),
498
+ 'skateboard': (160, 160, 0),
499
+ 'surfboard': (160, 0, 160),
500
+ 'tennis racket': (0, 160, 160),
501
+ 'bottle': (224, 0, 0),
502
+ 'wine glass': (0, 224, 0),
503
+ 'cup': (0, 0, 224),
504
+ 'fork': (224, 224, 0),
505
+ 'knife': (224, 0, 224),
506
+ 'spoon': (0, 224, 224),
507
+ 'bowl': (64, 64, 64),
508
+ 'banana': (128, 64, 64),
509
+ 'apple': (64, 128, 64),
510
+ 'sandwich': (64, 64, 128),
511
+ 'orange': (128, 128, 64),
512
+ 'broccoli': (128, 64, 128),
513
+ 'carrot': (64, 128, 128),
514
+ 'hot dog': (192, 64, 64),
515
+ 'pizza': (64, 192, 64),
516
+ 'donut': (64, 64, 192),
517
+ 'cake': (192, 192, 64),
518
+ 'chair': (192, 64, 192),
519
+ 'couch': (64, 192, 192),
520
+ 'potted plant': (96, 32, 32),
521
+ 'bed': (32, 96, 32),
522
+ 'dining table': (32, 32, 96),
523
+ 'toilet': (96, 96, 32),
524
+ 'tv': (96, 32, 96),
525
+ 'laptop': (32, 96, 96),
526
+ 'mouse': (160, 32, 32),
527
+ 'remote': (32, 160, 32),
528
+ 'keyboard': (32, 32, 160),
529
+ 'cell phone': (160, 160, 32),
530
+ 'microwave': (160, 32, 160),
531
+ 'oven': (32, 160, 160),
532
+ 'toaster': (224, 32, 32),
533
+ 'sink': (32, 224, 32),
534
+ 'refrigerator': (32, 32, 224),
535
+ 'book': (224, 224, 32),
536
+ 'clock': (224, 32, 224),
537
+ 'vase': (32, 224, 224),
538
+ 'scissors': (64, 96, 96),
539
+ 'teddy bear': (96, 64, 96),
540
+ 'hair drier': (96, 96, 64),
541
+ 'toothbrush': (160, 96, 96)
542
+ }
543
+
544
+
545
+ def process_bbox(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
546
+ strength, scale, seed, eta, confidence, nms_thresh, condition_mode):
547
+ with torch.no_grad():
548
+ input_image = HWC3(input_image)
549
+ img = resize_image(input_image, image_resolution)
550
+ H, W, C = img.shape
551
+
552
+ if condition_mode == True:
553
+ bbox, label, conf = cv.detect_common_objects(input_image, confidence=confidence, nms_thresh=nms_thresh)
554
+ mask = np.zeros((input_image.shape), np.uint8)
555
+ if len(bbox) > 0:
556
+ order_area = np.zeros(len(bbox))
557
+ # order_final = np.arange(len(bbox))
558
+ area_all = 0
559
+ for idx_mask, box in enumerate(bbox):
560
+ x_1, y_1, x_2, y_2 = box
561
+
562
+ x_1 = 0 if x_1 < 0 else x_1
563
+ y_1 = 0 if y_1 < 0 else y_1
564
+ x_2 = input_image.shape[1] if x_2 < 0 else x_2
565
+ y_2 = input_image.shape[0] if y_2 < 0 else y_2
566
+
567
+ area = (x_2 - x_1) * (y_2 - y_1)
568
+ order_area[idx_mask] = area
569
+ area_all += area
570
+ ordered_area = np.argsort(-order_area)
571
+
572
+ for idx_mask in ordered_area:
573
+ box = bbox[idx_mask]
574
+ x_1, y_1, x_2, y_2 = box
575
+ x_1 = 0 if x_1 < 0 else x_1
576
+ y_1 = 0 if y_1 < 0 else y_1
577
+ x_2 = input_image.shape[1] if x_2 < 0 else x_2
578
+ y_2 = input_image.shape[0] if y_2 < 0 else y_2
579
+
580
+ mask[y_1:y_2, x_1:x_2, :] = color_dict[label[idx_mask]]
581
+ detected_map = mask
582
+ else:
583
+ detected_map = img
584
+
585
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
586
+
587
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
588
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
589
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
590
+
591
+ if seed == -1:
592
+ seed = random.randint(0, 65535)
593
+ seed_everything(seed)
594
+
595
+ if config.save_memory:
596
+ model.low_vram_shift(is_diffusing=False)
597
+
598
+ task = 'bbox'
599
+ task_dic = {}
600
+ task_dic['name'] = task_to_name[task]
601
+ task_instruction = name_to_instruction[task_dic['name']]
602
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
603
+
604
+ cond = {"c_concat": [control],
605
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
606
+ "task": task_dic}
607
+
608
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
609
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
610
+ shape = (4, H // 8, W // 8)
611
+
612
+ if config.save_memory:
613
+ model.low_vram_shift(is_diffusing=True)
614
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
615
+ [strength] * 13)
616
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
617
+ shape, cond, verbose=False, eta=eta,
618
+ unconditional_guidance_scale=scale,
619
+ unconditional_conditioning=un_cond)
620
+
621
+ if config.save_memory:
622
+ model.low_vram_shift(is_diffusing=False)
623
+
624
+ x_samples = model.decode_first_stage(samples)
625
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
626
+ 255).astype(
627
+ np.uint8)
628
+
629
+ results = [x_samples[i] for i in range(num_samples)]
630
+ return [detected_map] + results
631
+
632
+
633
+ def process_outpainting(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
634
+ strength, scale, seed, eta, height_top_extended, height_down_extended, width_left_extended, width_right_extended, condition_mode):
635
+ with torch.no_grad():
636
+ input_image = HWC3(input_image)
637
+ img = resize_image(input_image, image_resolution)
638
+ H, W, C = img.shape
639
+ if condition_mode == True:
640
+ detected_map = outpainting(input_image, image_resolution, height_top_extended, height_down_extended, width_left_extended, width_right_extended)
641
+ else:
642
+ detected_map = img
643
+
644
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
645
+
646
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
647
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
648
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
649
+
650
+ if seed == -1:
651
+ seed = random.randint(0, 65535)
652
+ seed_everything(seed)
653
+
654
+ if config.save_memory:
655
+ model.low_vram_shift(is_diffusing=False)
656
+
657
+ task = 'outpainting'
658
+ task_dic = {}
659
+ task_dic['name'] = task_to_name[task]
660
+ task_instruction = name_to_instruction[task_dic['name']]
661
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
662
+
663
+ cond = {"c_concat": [control],
664
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
665
+ "task": task_dic}
666
+
667
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
668
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
669
+ shape = (4, H // 8, W // 8)
670
+
671
+ if config.save_memory:
672
+ model.low_vram_shift(is_diffusing=True)
673
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
674
+ [strength] * 13)
675
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
676
+ shape, cond, verbose=False, eta=eta,
677
+ unconditional_guidance_scale=scale,
678
+ unconditional_conditioning=un_cond)
679
+
680
+ if config.save_memory:
681
+ model.low_vram_shift(is_diffusing=False)
682
+
683
+ x_samples = model.decode_first_stage(samples)
684
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
685
+ 255).astype(
686
+ np.uint8)
687
+
688
+ results = [x_samples[i] for i in range(num_samples)]
689
+ return [detected_map] + results
690
+
691
+
692
+ def process_sketch(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution,
693
+ ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode):
694
+ with torch.no_grad():
695
+ input_image = HWC3(input_image)
696
+ img = resize_image(input_image, image_resolution)
697
+ H, W, C = img.shape
698
+
699
+ if condition_mode == True:
700
+ detected_map = apply_hed(resize_image(input_image, detect_resolution))
701
+ detected_map = HWC3(detected_map)
702
+
703
+ # sketch the hed image
704
+ retry = 0
705
+ cnt = 0
706
+ while retry == 0:
707
+ threshold_value = np.random.randint(110, 160)
708
+ kernel_size = 3
709
+ alpha = 1.5
710
+ beta = 50
711
+ binary_image = cv2.threshold(detected_map, threshold_value, 255, cv2.THRESH_BINARY)[1]
712
+ inverted_image = cv2.bitwise_not(binary_image)
713
+ smoothed_image = cv2.GaussianBlur(inverted_image, (kernel_size, kernel_size), 0)
714
+ sketch_image = cv2.convertScaleAbs(smoothed_image, alpha=alpha, beta=beta)
715
+ if np.sum(sketch_image < 5) > 0.005 * sketch_image.shape[0] * sketch_image.shape[1] or cnt == 5:
716
+ retry = 1
717
+ else:
718
+ cnt += 1
719
+ detected_map = sketch_image
720
+ else:
721
+ detected_map = img
722
+
723
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
724
+
725
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
726
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
727
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
728
+
729
+ if seed == -1:
730
+ seed = random.randint(0, 65535)
731
+ seed_everything(seed)
732
+
733
+ if config.save_memory:
734
+ model.low_vram_shift(is_diffusing=False)
735
+
736
+ task = 'hedsketch'
737
+ task_dic = {}
738
+ task_dic['name'] = task_to_name[task]
739
+ task_instruction = name_to_instruction[task_dic['name']]
740
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
741
+
742
+ cond = {"c_concat": [control],
743
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
744
+ "task": task_dic}
745
+
746
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
747
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
748
+ shape = (4, H // 8, W // 8)
749
+
750
+ if config.save_memory:
751
+ model.low_vram_shift(is_diffusing=True)
752
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
753
+ [strength] * 13)
754
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
755
+ shape, cond, verbose=False, eta=eta,
756
+ unconditional_guidance_scale=scale,
757
+ unconditional_conditioning=un_cond)
758
+
759
+ if config.save_memory:
760
+ model.low_vram_shift(is_diffusing=False)
761
+
762
+ x_samples = model.decode_first_stage(samples)
763
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
764
+ 255).astype(
765
+ np.uint8)
766
+
767
+ results = [x_samples[i] for i in range(num_samples)]
768
+ return [detected_map] + results
769
+
770
+
771
+ def process_colorization(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
772
+ strength, scale, seed, eta, condition_mode):
773
+ with torch.no_grad():
774
+ input_image = HWC3(input_image)
775
+ img = resize_image(input_image, image_resolution)
776
+ H, W, C = img.shape
777
+ if condition_mode == True:
778
+ detected_map = grayscale(input_image, image_resolution)
779
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
780
+ detected_map = detected_map[:, :, np.newaxis]
781
+ detected_map = detected_map.repeat(3, axis=2)
782
+ else:
783
+ detected_map = img
784
+
785
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
786
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
787
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
788
+
789
+ if seed == -1:
790
+ seed = random.randint(0, 65535)
791
+ seed_everything(seed)
792
+
793
+ if config.save_memory:
794
+ model.low_vram_shift(is_diffusing=False)
795
+
796
+ task = 'grayscale'
797
+ task_dic = {}
798
+ task_dic['name'] = task_to_name[task]
799
+ task_instruction = name_to_instruction[task_dic['name']]
800
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
801
+
802
+ cond = {"c_concat": [control],
803
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
804
+ "task": task_dic}
805
+
806
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
807
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
808
+ shape = (4, H // 8, W // 8)
809
+
810
+ if config.save_memory:
811
+ model.low_vram_shift(is_diffusing=True)
812
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
813
+ [strength] * 13)
814
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
815
+ shape, cond, verbose=False, eta=eta,
816
+ unconditional_guidance_scale=scale,
817
+ unconditional_conditioning=un_cond)
818
+
819
+ if config.save_memory:
820
+ model.low_vram_shift(is_diffusing=False)
821
+
822
+ x_samples = model.decode_first_stage(samples)
823
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
824
+ 255).astype(
825
+ np.uint8)
826
+
827
+ results = [x_samples[i] for i in range(num_samples)]
828
+ return [detected_map] + results
829
+
830
+
831
+ def process_deblur(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
832
+ strength, scale, seed, eta, ksize, condition_mode):
833
+ with torch.no_grad():
834
+ input_image = HWC3(input_image)
835
+ img = resize_image(input_image, image_resolution)
836
+ H, W, C = img.shape
837
+ if condition_mode == True:
838
+ detected_map = blur(input_image, image_resolution, ksize)
839
+ else:
840
+ detected_map = img
841
+
842
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
843
+
844
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
845
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
846
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
847
+
848
+ if seed == -1:
849
+ seed = random.randint(0, 65535)
850
+ seed_everything(seed)
851
+
852
+ if config.save_memory:
853
+ model.low_vram_shift(is_diffusing=False)
854
+
855
+ task = 'blur'
856
+ task_dic = {}
857
+ task_dic['name'] = task_to_name[task]
858
+ task_instruction = name_to_instruction[task_dic['name']]
859
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
860
+
861
+ cond = {"c_concat": [control],
862
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
863
+ "task": task_dic}
864
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
865
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
866
+ shape = (4, H // 8, W // 8)
867
+
868
+ if config.save_memory:
869
+ model.low_vram_shift(is_diffusing=True)
870
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
871
+ [strength] * 13)
872
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
873
+ shape, cond, verbose=False, eta=eta,
874
+ unconditional_guidance_scale=scale,
875
+ unconditional_conditioning=un_cond)
876
+
877
+ if config.save_memory:
878
+ model.low_vram_shift(is_diffusing=False)
879
+
880
+ x_samples = model.decode_first_stage(samples)
881
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
882
+ 255).astype(
883
+ np.uint8)
884
+
885
+ results = [x_samples[i] for i in range(num_samples)]
886
+ return [detected_map] + results
887
+
888
+
889
+ def process_inpainting(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
890
+ strength, scale, seed, eta, h_ratio_t, h_ratio_d, w_ratio_l, w_ratio_r, condition_mode):
891
+ with torch.no_grad():
892
+ input_image = HWC3(input_image)
893
+ img = resize_image(input_image, image_resolution)
894
+ H, W, C = img.shape
895
+ if condition_mode == True:
896
+ detected_map = inpainting(input_image, image_resolution, h_ratio_t, h_ratio_d, w_ratio_l, w_ratio_r)
897
+ else:
898
+ detected_map = img
899
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
900
+
901
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
902
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
903
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
904
+
905
+ if seed == -1:
906
+ seed = random.randint(0, 65535)
907
+ seed_everything(seed)
908
+
909
+ if config.save_memory:
910
+ model.low_vram_shift(is_diffusing=False)
911
+
912
+ task = 'inpainting'
913
+ task_dic = {}
914
+ task_dic['name'] = task_to_name[task]
915
+ task_instruction = name_to_instruction[task_dic['name']]
916
+ task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :]
917
+
918
+ cond = {"c_concat": [control],
919
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
920
+ "task": task_dic}
921
+ un_cond = {"c_concat": [control * 0] if guess_mode else [control],
922
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
923
+ shape = (4, H // 8, W // 8)
924
+
925
+ if config.save_memory:
926
+ model.low_vram_shift(is_diffusing=True)
927
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
928
+ [strength] * 13)
929
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
930
+ shape, cond, verbose=False, eta=eta,
931
+ unconditional_guidance_scale=scale,
932
+ unconditional_conditioning=un_cond)
933
+
934
+ if config.save_memory:
935
+ model.low_vram_shift(is_diffusing=False)
936
+
937
+ x_samples = model.decode_first_stage(samples)
938
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
939
+ 255).astype(
940
+ np.uint8)
941
+
942
+ results = [x_samples[i] for i in range(num_samples)]
943
+ return [detected_map] + results
944
+
945
+
946
+ ############################################################################################################
947
+
948
+
949
+ demo = gr.Blocks()
950
+ with demo:
951
+ #gr.Markdown("UniControl Stable Diffusion Demo")
952
+ gr.HTML(
953
+ """
954
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
955
+ <h1 style="font-weight: 900; font-size: 2rem; margin: 0rem">
956
+ UniControl Stable Diffusion Demo
957
+ </h1>
958
+ <p style="font-size: 1rem; margin: 0rem">
959
+ Can Qin <sup>1,2</sup>, Shu Zhang<sup>1</sup>, Ning Yu <sup>1</sup>, Yihao Feng<sup>1</sup>, Xinyi Yang<sup>1</sup>, Yingbo Zhou <sup>1</sup>, Huan Wang <sup>1</sup>, Juan Carlos Niebles<sup>1</sup>, Caiming Xiong <sup>1</sup>, Silvio Savarese <sup>1</sup>, Stefano Ermon <sup>3</sup>, Yun Fu <sup>2</sup>, Ran Xu <sup>1</sup>
960
+ </p>
961
+ <p style="font-size: 0.8rem; margin: 0rem; line-height: 1em">
962
+ <sup>1</sup> Salesforce AI <sup>2</sup> Northeastern University <sup>3</sup> Stanford University
963
+ </p>
964
+ <p style="font-size: 0.8rem; margin: 0rem; line-height: 1em">
965
+ Work done when Can Qin was an intern at Salesforce AI Research.
966
+ </p>
967
+ <p style="font-size: 0.9rem; margin: 0rem; line-height: 1.2em; margin-top:1em">
968
+ <b> ONE compact model for ALL the visual-condition-to-image generation! </b>
969
+ <b><a href="https://github.com/salesforce/UniControl">[Github]</a></b>
970
+ <b><a href="https://canqin001.github.io/UniControl-Page/">[Website]</a></b>
971
+ <b><a href="https://arxiv.org/abs/2305.11147">[arXiv]</a></b>
972
+ </p>
973
+ </div>
974
+ """)
975
+
976
+ with gr.Tabs():
977
+ with gr.TabItem("Canny"):
978
+ with gr.Row():
979
+ gr.Markdown("## UniControl Stable Diffusion with Canny Edge Maps")
980
+ with gr.Row():
981
+ with gr.Column():
982
+ input_image = gr.Image(source='upload', type="numpy")
983
+ prompt = gr.Textbox(label="Prompt")
984
+ run_button = gr.Button(label="Run")
985
+ with gr.Accordion("Advanced options", open=False):
986
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
987
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
988
+ step=64)
989
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
990
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Canny', value=True)
991
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
992
+ low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=40, step=1)
993
+ high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200,
994
+ step=1)
995
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
996
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
997
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
998
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
999
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, bright')
1000
+ n_prompt = gr.Textbox(label="Negative Prompt", value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
1001
+ with gr.Column():
1002
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1003
+ height='auto')
1004
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
1005
+ strength, scale, seed, eta, low_threshold, high_threshold, condition_mode]
1006
+ run_button.click(fn=process_canny, inputs=ips, outputs=[result_gallery])
1007
+
1008
+ with gr.TabItem("HED"):
1009
+ with gr.Row():
1010
+ gr.Markdown("## UniControl Stable Diffusion with HED Maps")
1011
+ with gr.Row():
1012
+ with gr.Column():
1013
+ input_image = gr.Image(source='upload', type="numpy")
1014
+ prompt = gr.Textbox(label="Prompt")
1015
+ run_button = gr.Button(label="Run")
1016
+ with gr.Accordion("Advanced options", open=False):
1017
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1018
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1019
+ step=64)
1020
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1021
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> HED', value=True)
1022
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1023
+ detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512,
1024
+ step=1)
1025
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1026
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1027
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1028
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1029
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, bright')
1030
+ n_prompt = gr.Textbox(label="Negative Prompt", value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
1031
+ with gr.Column():
1032
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1033
+ height='auto')
1034
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution,
1035
+ ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode]
1036
+ run_button.click(fn=process_hed, inputs=ips, outputs=[result_gallery])
1037
+
1038
+ with gr.TabItem("Sketch"):
1039
+ with gr.Row():
1040
+ gr.Markdown("## UniControl Stable Diffusion with Sketch Maps")
1041
+ with gr.Row():
1042
+ with gr.Column():
1043
+ input_image = gr.Image(source='upload', type="numpy")
1044
+ prompt = gr.Textbox(label="Prompt")
1045
+ run_button = gr.Button(label="Run")
1046
+ with gr.Accordion("Advanced options", open=False):
1047
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1048
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1049
+ step=64)
1050
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1051
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Sketch', value=False)
1052
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1053
+ detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512,
1054
+ step=1)
1055
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1056
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1057
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1058
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1059
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
1060
+ n_prompt = gr.Textbox(label="Negative Prompt", value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
1061
+ with gr.Column():
1062
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1063
+ height='auto')
1064
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution,
1065
+ ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode]
1066
+ run_button.click(fn=process_sketch, inputs=ips, outputs=[result_gallery])
1067
+
1068
+ with gr.TabItem("Depth"):
1069
+ with gr.Row():
1070
+ gr.Markdown("## UniControl Stable Diffusion with Depth Maps")
1071
+ with gr.Row():
1072
+ with gr.Column():
1073
+ input_image = gr.Image(source='upload', type="numpy")
1074
+ prompt = gr.Textbox(label="Prompt")
1075
+ run_button = gr.Button(label="Run")
1076
+ with gr.Accordion("Advanced options", open=False):
1077
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1078
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1079
+ step=64)
1080
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1081
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Depth', value=True)
1082
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1083
+ detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384,
1084
+ step=1)
1085
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1086
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1087
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1088
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1089
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, bright')
1090
+ n_prompt = gr.Textbox(label="Negative Prompt", value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
1091
+ with gr.Column():
1092
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1093
+ height='auto')
1094
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution,
1095
+ ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode]
1096
+ run_button.click(fn=process_depth, inputs=ips, outputs=[result_gallery])
1097
+
1098
+ with gr.TabItem("Normal"):
1099
+ with gr.Row():
1100
+ gr.Markdown("## UniControl Stable Diffusion with Normal Surface")
1101
+ with gr.Row():
1102
+ with gr.Column():
1103
+ input_image = gr.Image(source='upload', type="numpy")
1104
+ prompt = gr.Textbox(label="Prompt")
1105
+ run_button = gr.Button(label="Run")
1106
+ with gr.Accordion("Advanced options", open=False):
1107
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1108
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1109
+ step=64)
1110
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1111
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Normal', value=True)
1112
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1113
+ detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384,
1114
+ step=1)
1115
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1116
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1117
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1118
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1119
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, bright')
1120
+ n_prompt = gr.Textbox(label="Negative Prompt", value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
1121
+ with gr.Column():
1122
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1123
+ height='auto')
1124
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution,
1125
+ ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode]
1126
+ run_button.click(fn=process_normal, inputs=ips, outputs=[result_gallery])
1127
+
1128
+ with gr.TabItem("Human Pose"):
1129
+ with gr.Row():
1130
+ gr.Markdown("## UniControl Stable Diffusion with Human Pose")
1131
+ with gr.Row():
1132
+ with gr.Column():
1133
+ input_image = gr.Image(source='upload', type="numpy")
1134
+ prompt = gr.Textbox(label="Prompt")
1135
+ run_button = gr.Button(label="Run")
1136
+ with gr.Accordion("Advanced options", open=False):
1137
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1138
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1139
+ step=64)
1140
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1141
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Skeleton', value=True)
1142
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1143
+ detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=1024, value=512,
1144
+ step=1)
1145
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1146
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1147
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1148
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1149
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, bright')
1150
+ n_prompt = gr.Textbox(label="Negative Prompt", value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
1151
+ with gr.Column():
1152
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1153
+ height='auto')
1154
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution,
1155
+ ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode]
1156
+ run_button.click(fn=process_pose, inputs=ips, outputs=[result_gallery])
1157
+
1158
+ with gr.TabItem("Segmentation"):
1159
+ with gr.Row():
1160
+ gr.Markdown("## UniControl Stable Diffusion with Segmentation Maps (ADE20K)")
1161
+ with gr.Row():
1162
+ with gr.Column():
1163
+ input_image = gr.Image(source='upload', type="numpy")
1164
+ prompt = gr.Textbox(label="Prompt")
1165
+ run_button = gr.Button(label="Run")
1166
+ with gr.Accordion("Advanced options", open=False):
1167
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1168
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1169
+ step=64)
1170
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1171
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Seg', value=True)
1172
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1173
+ detect_resolution = gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024,
1174
+ value=512, step=1)
1175
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1176
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1177
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1178
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1179
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, bright')
1180
+ n_prompt = gr.Textbox(label="Negative Prompt", value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
1181
+ with gr.Column():
1182
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1183
+ height='auto')
1184
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution,
1185
+ ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode]
1186
+ run_button.click(fn=process_seg, inputs=ips, outputs=[result_gallery])
1187
+
1188
+ with gr.TabItem("Bbox"):
1189
+ with gr.Row():
1190
+ gr.Markdown("## UniControl Stable Diffusion with Object Bounding Boxes (MS-COCO)")
1191
+ with gr.Row():
1192
+ with gr.Column():
1193
+ input_image = gr.Image(source='upload', type="numpy")
1194
+ prompt = gr.Textbox(label="Prompt")
1195
+ run_button = gr.Button(label="Run")
1196
+ with gr.Accordion("Advanced options", open=False):
1197
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1198
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1199
+ step=64)
1200
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1201
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Bbox', value=True)
1202
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1203
+ confidence = gr.Slider(label="Confidence of Detection", minimum=0.1, maximum=1.0, value=0.4,
1204
+ step=0.1)
1205
+ nms_thresh = gr.Slider(label="Nms Threshold", minimum=0.1, maximum=1.0, value=0.5, step=0.1)
1206
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1207
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1208
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1209
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1210
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, bright')
1211
+ n_prompt = gr.Textbox(label="Negative Prompt", value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
1212
+ with gr.Column():
1213
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1214
+ height='auto')
1215
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
1216
+ strength, scale, seed, eta, confidence, nms_thresh, condition_mode]
1217
+ run_button.click(fn=process_bbox, inputs=ips, outputs=[result_gallery])
1218
+
1219
+ with gr.TabItem("Outpainting"):
1220
+ with gr.Row():
1221
+ gr.Markdown("## UniControl Stable Diffusion with Image Outpainting")
1222
+ with gr.Row():
1223
+ with gr.Column():
1224
+ input_image = gr.Image(source='upload', type="numpy")
1225
+ prompt = gr.Textbox(label="Prompt")
1226
+ run_button = gr.Button(label="Run")
1227
+ with gr.Accordion("Advanced options", open=False):
1228
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1229
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1230
+ step=64)
1231
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1232
+ condition_mode = gr.Checkbox(label='Condition Extraction: Extending', value=False)
1233
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1234
+
1235
+ height_top_extended = gr.Slider(label="Top Extended Ratio (%)", minimum=1, maximum=200,
1236
+ value=50, step=1)
1237
+ height_down_extended = gr.Slider(label="Down Extended Ratio (%)", minimum=1, maximum=200,
1238
+ value=50, step=1)
1239
+
1240
+ width_left_extended = gr.Slider(label="Left Extended Ratio (%)", minimum=1, maximum=200,
1241
+ value=50, step=1)
1242
+ width_right_extended = gr.Slider(label="Right Extended Ratio (%)", minimum=1, maximum=200,
1243
+ value=50, step=1)
1244
+
1245
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1246
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1247
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1248
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1249
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
1250
+ n_prompt = gr.Textbox(label="Negative Prompt", value='')
1251
+ with gr.Column():
1252
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1253
+ height='auto')
1254
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
1255
+ strength, scale, seed, eta, height_top_extended, height_down_extended, width_left_extended, width_right_extended, condition_mode]
1256
+ run_button.click(fn=process_outpainting, inputs=ips, outputs=[result_gallery])
1257
+
1258
+ with gr.TabItem("Inpainting"):
1259
+ with gr.Row():
1260
+ gr.Markdown("## UniControl Stable Diffusion with Image Inpainting")
1261
+ with gr.Row():
1262
+ with gr.Column():
1263
+ input_image = gr.Image(source='upload', type="numpy")
1264
+ prompt = gr.Textbox(label="Prompt")
1265
+ run_button = gr.Button(label="Run")
1266
+ with gr.Accordion("Advanced options", open=False):
1267
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1268
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1269
+ step=64)
1270
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1271
+ condition_mode = gr.Checkbox(label='Condition Extraction: Cropped Masking', value=False)
1272
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1273
+ h_ratio_t = gr.Slider(label="Top Masking Ratio (%)", minimum=0, maximum=100, value=30,
1274
+ step=1)
1275
+ h_ratio_d = gr.Slider(label="Down Masking Ratio (%)", minimum=0, maximum=100, value=60,
1276
+ step=1)
1277
+ w_ratio_l = gr.Slider(label="Left Masking Ratio (%)", minimum=0, maximum=100, value=30,
1278
+ step=1)
1279
+ w_ratio_r = gr.Slider(label="Right Masking Ratio (%)", minimum=0, maximum=100, value=60,
1280
+ step=1)
1281
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1282
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1283
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1284
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1285
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
1286
+ n_prompt = gr.Textbox(label="Negative Prompt", value='')
1287
+ with gr.Column():
1288
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1289
+ height='auto')
1290
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
1291
+ strength, scale, seed, eta, h_ratio_t, h_ratio_d, w_ratio_l, w_ratio_r, condition_mode]
1292
+ run_button.click(fn=process_inpainting, inputs=ips, outputs=[result_gallery])
1293
+
1294
+ with gr.TabItem("Colorization"):
1295
+ with gr.Row():
1296
+ gr.Markdown("## UniControl Stable Diffusion with Gray Image Colorization")
1297
+ with gr.Row():
1298
+ with gr.Column():
1299
+ input_image = gr.Image(source='upload', type="numpy")
1300
+ prompt = gr.Textbox(label="Prompt")
1301
+ run_button = gr.Button(label="Run")
1302
+ with gr.Accordion("Advanced options", open=False):
1303
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1304
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1305
+ step=64)
1306
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1307
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Gray', value=False)
1308
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1309
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1310
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1311
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1312
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1313
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, colorful')
1314
+ n_prompt = gr.Textbox(label="Negative Prompt", value='')
1315
+ with gr.Column():
1316
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1317
+ height='auto')
1318
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
1319
+ strength, scale, seed, eta, condition_mode]
1320
+ run_button.click(fn=process_colorization, inputs=ips, outputs=[result_gallery])
1321
+
1322
+ with gr.TabItem("Deblurring"):
1323
+ with gr.Row():
1324
+ gr.Markdown("## UniControl Stable Diffusion with Image Deblurring")
1325
+ with gr.Row():
1326
+ with gr.Column():
1327
+ input_image = gr.Image(source='upload', type="numpy")
1328
+ prompt = gr.Textbox(label="Prompt")
1329
+ run_button = gr.Button(label="Run")
1330
+ with gr.Accordion("Advanced options", open=False):
1331
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1332
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512,
1333
+ step=64)
1334
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
1335
+ condition_mode = gr.Checkbox(label='Condition Extraction: RGB -> Blur', value=False)
1336
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
1337
+ ksize = gr.Slider(label="Kernel Size", minimum=11, maximum=101, value=51, step=2)
1338
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1339
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
1340
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
1341
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
1342
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
1343
+ n_prompt = gr.Textbox(label="Negative Prompt", value='')
1344
+ with gr.Column():
1345
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
1346
+ height='auto')
1347
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode,
1348
+ strength, scale, seed, eta, ksize, condition_mode]
1349
+ run_button.click(fn=process_deblur, inputs=ips, outputs=[result_gallery])
1350
+
1351
+
1352
+ gr.Markdown('''### Tips
1353
+ - Please pay attention to <u> Condition Extraction </u> option.
1354
+ - Positive prompts and negative prompts are very useful sometimes.
1355
+ ''')
1356
+ gr.Markdown('''### Related Spaces
1357
+ - https://huggingface.co/spaces/hysts/ControlNet
1358
+ - https://huggingface.co/spaces/shi-labs/Prompt-Free-Diffusion
1359
+ ''')
1360
+ demo.launch()
annotator/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.62 kB). View file
 
annotator/__pycache__/util.cpython-38.pyc ADDED
Binary file (1.6 kB). View file
 
annotator/blur/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+ class Blurrer:
4
+ def __call__(self, img, ksize):
5
+ img_new = cv2.GaussianBlur(img, (ksize, ksize), cv2.BORDER_DEFAULT)
6
+ img_new = img_new.astype('ubyte')
7
+ return img_new
annotator/blur/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (550 Bytes). View file
 
annotator/blur/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (507 Bytes). View file
 
annotator/canny/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2023 Salesforce, Inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: Apache License 2.0
5
+ * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6
+ * By Can Qin
7
+ * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
8
+ * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
9
+ '''
10
+
11
+ import cv2
12
+
13
+
14
+ class CannyDetector:
15
+ def __call__(self, img, low_threshold, high_threshold):
16
+ return cv2.Canny(img, low_threshold, high_threshold)
annotator/canny/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (890 Bytes). View file
 
annotator/canny/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (847 Bytes). View file
 
annotator/ckpts/ckpts.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Weights here.
annotator/grayscale/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from skimage import color
2
+
3
+ class GrayscaleConverter:
4
+ def __call__(self, img):
5
+ return (color.rgb2gray(img) * 255.0).astype('ubyte')
annotator/grayscale/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (554 Bytes). View file
 
annotator/grayscale/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (511 Bytes). View file
 
annotator/hed/__init__.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2023 Salesforce, Inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: Apache License 2.0
5
+ * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6
+ * By Can Qin
7
+ * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
8
+ * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
9
+ '''
10
+
11
+ # This is an improved version and model of HED edge detection without GPL contamination
12
+ # Please use this implementation in your products
13
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
14
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
15
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
16
+ # and in this way it works better for gradio's RGB protocol
17
+
18
+ import os
19
+ import cv2
20
+ import torch
21
+ import numpy as np
22
+
23
+ from einops import rearrange
24
+ from annotator.util import annotator_ckpts_path
25
+
26
+
27
+ class DoubleConvBlock(torch.nn.Module):
28
+ def __init__(self, input_channel, output_channel, layer_number):
29
+ super().__init__()
30
+ self.convs = torch.nn.Sequential()
31
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
32
+ for i in range(1, layer_number):
33
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
34
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
35
+
36
+ def __call__(self, x, down_sampling=False):
37
+ h = x
38
+ if down_sampling:
39
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
40
+ for conv in self.convs:
41
+ h = conv(h)
42
+ h = torch.nn.functional.relu(h)
43
+ return h, self.projection(h)
44
+
45
+
46
+ class ControlNetHED_Apache2(torch.nn.Module):
47
+ def __init__(self):
48
+ super().__init__()
49
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
50
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
51
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
52
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
53
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
54
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
55
+
56
+ def __call__(self, x):
57
+ h = x - self.norm
58
+ h, projection1 = self.block1(h)
59
+ h, projection2 = self.block2(h, down_sampling=True)
60
+ h, projection3 = self.block3(h, down_sampling=True)
61
+ h, projection4 = self.block4(h, down_sampling=True)
62
+ h, projection5 = self.block5(h, down_sampling=True)
63
+ return projection1, projection2, projection3, projection4, projection5
64
+
65
+
66
+ class HEDdetector:
67
+ def __init__(self):
68
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
69
+ modelpath = remote_model_path
70
+ modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
71
+ if not os.path.exists(modelpath):
72
+ from basicsr.utils.download_util import load_file_from_url
73
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
74
+ self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
75
+ self.netNetwork.load_state_dict(torch.load(modelpath))
76
+
77
+ def __call__(self, input_image):
78
+ assert input_image.ndim == 3
79
+ H, W, C = input_image.shape
80
+ with torch.no_grad():
81
+ image_hed = torch.from_numpy(input_image.copy()).float().cuda()
82
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
83
+ edges = self.netNetwork(image_hed)
84
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
85
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
86
+ edges = np.stack(edges, axis=2)
87
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
88
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
89
+ return edge
90
+
91
+
92
+ def nms(x, t, s):
93
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
94
+
95
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
96
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
97
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
98
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
99
+
100
+ y = np.zeros_like(x)
101
+
102
+ for f in [f1, f2, f3, f4]:
103
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
104
+
105
+ z = np.zeros_like(y, dtype=np.uint8)
106
+ z[y > t] = 255
107
+ return z
annotator/hed/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (4.78 kB). View file
 
annotator/hed/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (4.69 kB). View file
 
annotator/inpainting/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ class Inpainter:
4
+ def __call__(self, img, height_top_mask, height_down_mask, width_left_mask, width_right_mask):
5
+ h = img.shape[0]
6
+ w = img.shape[1]
7
+ h_top_mask = int(float(h) / 100.0 * float(height_top_mask))
8
+ h_down_mask = int(float(h) / 100.0 * float(height_down_mask))
9
+
10
+ w_left_mask = int(float(w) / 100.0 * float(width_left_mask))
11
+ w_right_mask = int(float(w) / 100.0 * float(width_right_mask))
12
+
13
+ img_new = img
14
+ img_new[h_top_mask:h_down_mask, w_left_mask:w_right_mask] = 0
15
+ img_new = img_new.astype('ubyte')
16
+ return img_new
annotator/inpainting/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (824 Bytes). View file
 
annotator/inpainting/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (764 Bytes). View file
 
annotator/midas/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
annotator/midas/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2023 Salesforce, Inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: Apache License 2.0
5
+ * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6
+ * By Can Qin
7
+ * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
8
+ * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
9
+ '''
10
+
11
+ # Midas Depth Estimation
12
+ # From https://github.com/isl-org/MiDaS
13
+ # MIT LICENSE
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+
19
+ from einops import rearrange
20
+ from .api import MiDaSInference
21
+
22
+
23
+ class MidasDetector:
24
+ def __init__(self):
25
+ self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
26
+
27
+ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
28
+ assert input_image.ndim == 3
29
+ image_depth = input_image
30
+ with torch.no_grad():
31
+ image_depth = torch.from_numpy(image_depth).float().cuda()
32
+ image_depth = image_depth / 127.5 - 1.0
33
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
34
+ depth = self.model(image_depth)[0]
35
+
36
+ depth_pt = depth.clone()
37
+ depth_pt -= torch.min(depth_pt)
38
+ depth_pt /= torch.max(depth_pt)
39
+ depth_pt = depth_pt.cpu().numpy()
40
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
41
+
42
+ depth_np = depth.cpu().numpy()
43
+ x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
44
+ y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
45
+ z = np.ones_like(x) * a
46
+ x[depth_pt < bg_th] = 0
47
+ y[depth_pt < bg_th] = 0
48
+ normal = np.stack([x, y, z], axis=2)
49
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
50
+ normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
51
+
52
+ return depth_image, normal_image
annotator/midas/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.96 kB). View file
 
annotator/midas/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.92 kB). View file
 
annotator/midas/__pycache__/api.cpython-310.pyc ADDED
Binary file (4.1 kB). View file
 
annotator/midas/__pycache__/api.cpython-38.pyc ADDED
Binary file (4.14 kB). View file
 
annotator/midas/api.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2023 Salesforce, Inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: Apache License 2.0
5
+ * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6
+ * By Can Qin
7
+ * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
8
+ * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
9
+ '''
10
+
11
+ # based on https://github.com/isl-org/MiDaS
12
+
13
+ import cv2
14
+ import os
15
+ import torch
16
+ import torch.nn as nn
17
+ from torchvision.transforms import Compose
18
+
19
+ from .midas.dpt_depth import DPTDepthModel
20
+ from .midas.midas_net import MidasNet
21
+ from .midas.midas_net_custom import MidasNet_small
22
+ from .midas.transforms import Resize, NormalizeImage, PrepareForNet
23
+ from annotator.util import annotator_ckpts_path
24
+
25
+
26
+ ISL_PATHS = {
27
+ "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large_384.pt"),
28
+ "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
29
+ "midas_v21": "",
30
+ "midas_v21_small": "",
31
+ }
32
+
33
+ remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
34
+ # remote_model_path = "https://storage.googleapis.com/sfr-unicontrol-data-research/annotator/ckpts/dpt_large_384.pt" #"https://huggingface.co/Salesforce/UniControl/blob/main/annotator/ckpts/dpt_large_384.pt"
35
+
36
+ def disabled_train(self, mode=True):
37
+ """Overwrite model.train with this function to make sure train/eval mode
38
+ does not change anymore."""
39
+ return self
40
+
41
+
42
+ def load_midas_transform(model_type):
43
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
44
+ # load transform only
45
+ if model_type == "dpt_large": # DPT-Large
46
+ net_w, net_h = 384, 384
47
+ resize_mode = "minimal"
48
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
49
+
50
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
51
+ net_w, net_h = 384, 384
52
+ resize_mode = "minimal"
53
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
54
+
55
+ elif model_type == "midas_v21":
56
+ net_w, net_h = 384, 384
57
+ resize_mode = "upper_bound"
58
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
59
+
60
+ elif model_type == "midas_v21_small":
61
+ net_w, net_h = 256, 256
62
+ resize_mode = "upper_bound"
63
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
64
+
65
+ else:
66
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
67
+
68
+ transform = Compose(
69
+ [
70
+ Resize(
71
+ net_w,
72
+ net_h,
73
+ resize_target=None,
74
+ keep_aspect_ratio=True,
75
+ ensure_multiple_of=32,
76
+ resize_method=resize_mode,
77
+ image_interpolation_method=cv2.INTER_CUBIC,
78
+ ),
79
+ normalization,
80
+ PrepareForNet(),
81
+ ]
82
+ )
83
+
84
+ return transform
85
+
86
+
87
+ def load_model(model_type):
88
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
89
+ # load network
90
+ model_path = ISL_PATHS[model_type]
91
+ if model_type == "dpt_large": # DPT-Large
92
+ if not os.path.exists(model_path):
93
+ from basicsr.utils.download_util import load_file_from_url
94
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
95
+ #model_path = remote_model_path
96
+ model = DPTDepthModel(
97
+ path=model_path,
98
+ backbone="vitl16_384",
99
+ non_negative=True,
100
+ )
101
+ net_w, net_h = 384, 384
102
+ resize_mode = "minimal"
103
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
104
+
105
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
106
+ if not os.path.exists(model_path):
107
+ from basicsr.utils.download_util import load_file_from_url
108
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
109
+
110
+ model = DPTDepthModel(
111
+ path=model_path,
112
+ backbone="vitb_rn50_384",
113
+ non_negative=True,
114
+ )
115
+ net_w, net_h = 384, 384
116
+ resize_mode = "minimal"
117
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
118
+
119
+ elif model_type == "midas_v21":
120
+ model = MidasNet(model_path, non_negative=True)
121
+ net_w, net_h = 384, 384
122
+ resize_mode = "upper_bound"
123
+ normalization = NormalizeImage(
124
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
125
+ )
126
+
127
+ elif model_type == "midas_v21_small":
128
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
129
+ non_negative=True, blocks={'expand': True})
130
+ net_w, net_h = 256, 256
131
+ resize_mode = "upper_bound"
132
+ normalization = NormalizeImage(
133
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
134
+ )
135
+
136
+ else:
137
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
138
+ assert False
139
+
140
+ transform = Compose(
141
+ [
142
+ Resize(
143
+ net_w,
144
+ net_h,
145
+ resize_target=None,
146
+ keep_aspect_ratio=True,
147
+ ensure_multiple_of=32,
148
+ resize_method=resize_mode,
149
+ image_interpolation_method=cv2.INTER_CUBIC,
150
+ ),
151
+ normalization,
152
+ PrepareForNet(),
153
+ ]
154
+ )
155
+
156
+ return model.eval(), transform
157
+
158
+
159
+ class MiDaSInference(nn.Module):
160
+ MODEL_TYPES_TORCH_HUB = [
161
+ "DPT_Large",
162
+ "DPT_Hybrid",
163
+ "MiDaS_small"
164
+ ]
165
+ MODEL_TYPES_ISL = [
166
+ "dpt_large",
167
+ "dpt_hybrid",
168
+ "midas_v21",
169
+ "midas_v21_small",
170
+ ]
171
+
172
+ def __init__(self, model_type):
173
+ super().__init__()
174
+ assert (model_type in self.MODEL_TYPES_ISL)
175
+ model, _ = load_model(model_type)
176
+ self.model = model
177
+ self.model.train = disabled_train
178
+
179
+ def forward(self, x):
180
+ with torch.no_grad():
181
+ prediction = self.model(x)
182
+ return prediction
183
+
annotator/midas/midas/__init__.py ADDED
File without changes
annotator/midas/midas/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (167 Bytes). View file
 
annotator/midas/midas/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (126 Bytes). View file
 
annotator/midas/midas/__pycache__/base_model.cpython-310.pyc ADDED
Binary file (1.07 kB). View file
 
annotator/midas/midas/__pycache__/base_model.cpython-38.pyc ADDED
Binary file (1.03 kB). View file
 
annotator/midas/midas/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (7.58 kB). View file
 
annotator/midas/midas/__pycache__/blocks.cpython-38.pyc ADDED
Binary file (7.72 kB). View file
 
annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc ADDED
Binary file (3.3 kB). View file
 
annotator/midas/midas/__pycache__/dpt_depth.cpython-38.pyc ADDED
Binary file (3.21 kB). View file
 
annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc ADDED
Binary file (2.72 kB). View file
 
annotator/midas/midas/__pycache__/midas_net.cpython-38.pyc ADDED
Binary file (2.67 kB). View file
 
annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc ADDED
Binary file (3.84 kB). View file
 
annotator/midas/midas/__pycache__/midas_net_custom.cpython-38.pyc ADDED
Binary file (3.79 kB). View file
 
annotator/midas/midas/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (6.05 kB). View file
 
annotator/midas/midas/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (6.05 kB). View file
 
annotator/midas/midas/__pycache__/vit.cpython-310.pyc ADDED
Binary file (9.74 kB). View file
 
annotator/midas/midas/__pycache__/vit.cpython-38.pyc ADDED
Binary file (10.1 kB). View file
 
annotator/midas/midas/base_model.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2023 Salesforce, Inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: Apache License 2.0
5
+ * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6
+ * By Can Qin
7
+ * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
8
+ * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
9
+ '''
10
+
11
+ import torch
12
+
13
+
14
+ class BaseModel(torch.nn.Module):
15
+ def load(self, path):
16
+ """Load model from file.
17
+
18
+ Args:
19
+ path (str): file path
20
+ """
21
+ parameters = torch.load(path, map_location=torch.device('cpu'))
22
+
23
+ if "optimizer" in parameters:
24
+ parameters = parameters["model"]
25
+
26
+ self.load_state_dict(parameters)