yamildiego committed
Commit 158f9b4
1 Parent(s): 0b11b0c

rollback self.ip_ckpt

Files changed (2)
  1. handler.py +220 -12
  2. requirements.txt +16 -2
handler.py CHANGED
@@ -1,15 +1,223 @@
+import cv2
+import torch
+import random
+import numpy as np
+
+from PIL import Image
+from pathlib import Path
+
+from huggingface_hub import hf_hub_download, snapshot_download
+from ip_adapter.ip_adapter import IPAdapterXL
+from safetensors.torch import load_file
+import os
+
+from diffusers import (
+    ControlNetModel,
+    StableDiffusionXLControlNetPipeline,
+    UNet2DConditionModel,
+    EulerDiscreteScheduler,
+)
+from diffusers.utils import load_image
+
+# global variables
+MAX_SEED = np.iinfo(np.int32).max
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if "cuda" in device else torch.float32
+
+# initialization
+base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+# image_encoder_path = "sdxl_models/image_encoder"
+# ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
+controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
+
+
 class EndpointHandler():
-    def __init__(self, path=""):
-        pass
+    def __init__(self, model_dir):
+
+        repo_id = "h94/IP-Adapter"
+
+        # download the repo snapshot and resolve the image_encoder directory
+        # and the SDXL checkpoint inside it
+        local_repo_path = snapshot_download(repo_id=repo_id)
+        # image_encoder_local_path = os.path.join(local_repo_path, "image_encoder")
+        self.image_encoder_local_path = os.path.join(local_repo_path, "sdxl_models", "image_encoder")
+        self.ip_ckpt = os.path.join(local_repo_path, "sdxl_models", "ip-adapter_sdxl.bin")
+
+        self.controlnet = ControlNetModel.from_pretrained(
+            controlnet_path, use_safetensors=False, torch_dtype=torch.float16
+        ).to(device)
+
+        # load SDXL Lightning
+        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            base_model_path,
+            controlnet=self.controlnet,
+            torch_dtype=torch.float16,
+            variant="fp16",
+            add_watermarker=False,
+        ).to(device)
+        self.pipe.set_progress_bar_config(disable=True)
+        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
+            self.pipe.scheduler.config, timestep_spacing="trailing", prediction_type="epsilon"
+        )
+        self.pipe.unet.load_state_dict(
+            load_file(
+                hf_hub_download(
+                    "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
+                ),
+                device=device,
+            )
+        )
+
+        self.ip_model = IPAdapterXL(
+            self.pipe,
+            self.image_encoder_local_path,
+            self.ip_ckpt,
+            device,
+            target_blocks=["up_blocks.0.attentions.1"],
+        )
 
     def __call__(self, data):
-        """
-        data args:
-            inputs (:obj: `str`)
-            date (:obj: `str`)
-        Return:
-            A :obj:`list` | `dict`: will be serialized and returned
-        """
-        inputs = data.pop("inputs",data)
-
-        return inputs
+
+        def create_image(
+            image_pil,
+            input_image,
+            prompt,
+            n_prompt,
+            scale,
+            control_scale,
+            guidance_scale,
+            num_inference_steps,
+            seed,
+            target="Load only style blocks",
+            neg_content_prompt=None,
+            neg_content_scale=0,
+        ):
+            seed = random.randint(0, MAX_SEED) if seed == -1 else seed
+            if target == "Load original IP-Adapter":
+                # target_blocks=["blocks"] for original IP-Adapter
+                ip_model = IPAdapterXL(
+                    self.pipe, self.image_encoder_local_path, self.ip_ckpt, device, target_blocks=["blocks"]
+                )
+            elif target == "Load only style blocks":
+                # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
+                ip_model = IPAdapterXL(
+                    self.pipe,
+                    self.image_encoder_local_path,
+                    self.ip_ckpt,
+                    device,
+                    target_blocks=["up_blocks.0.attentions.1"],
+                )
+            elif target == "Load style+layout block":
+                # target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] for style+layout blocks
+                ip_model = IPAdapterXL(
+                    self.pipe,
+                    self.image_encoder_local_path,
+                    self.ip_ckpt,
+                    device,
+                    target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"],
+                )
+
+            if input_image is not None:
+                input_image = resize_img(input_image, max_side=1024)
+                cv_input_image = pil_to_cv2(input_image)
+                detected_map = cv2.Canny(cv_input_image, 50, 200)
+                canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
+            else:
+                canny_map = Image.new("RGB", (1024, 1024), color=(255, 255, 255))
+                control_scale = 0
+
+            if float(control_scale) == 0:
+                canny_map = canny_map.resize((1024, 1024))
+
+            # truthiness check guards against the None default before measuring length
+            if neg_content_prompt and neg_content_scale != 0:
+                images = ip_model.generate(
+                    pil_image=image_pil,
+                    prompt=prompt,
+                    negative_prompt=n_prompt,
+                    scale=scale,
+                    guidance_scale=guidance_scale,
+                    num_samples=1,
+                    num_inference_steps=num_inference_steps,
+                    seed=seed,
+                    image=canny_map,
+                    controlnet_conditioning_scale=float(control_scale),
+                    neg_content_prompt=neg_content_prompt,
+                    neg_content_scale=neg_content_scale,
+                )
+            else:
+                images = ip_model.generate(
+                    pil_image=image_pil,
+                    prompt=prompt,
+                    negative_prompt=n_prompt,
+                    scale=scale,
+                    guidance_scale=guidance_scale,
+                    num_samples=1,
+                    num_inference_steps=num_inference_steps,
+                    seed=seed,
+                    image=canny_map,
+                    controlnet_conditioning_scale=float(control_scale),
+                )
+            image = images[0]
+
+            return image
+
+        def pil_to_cv2(image_pil):
+            image_np = np.array(image_pil)
+            image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+            return image_cv2
+
+        def resize_img(
+            input_image,
+            max_side=1280,
+            min_side=1024,
+            size=None,
+            pad_to_max_side=False,
+            mode=Image.BILINEAR,
+            base_pixel_number=64,
+        ):
+            w, h = input_image.size
+            if size is not None:
+                w_resize_new, h_resize_new = size
+            else:
+                ratio = min_side / min(h, w)
+                w, h = round(ratio * w), round(ratio * h)
+                ratio = max_side / max(h, w)
+                input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
+                w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+                h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+            input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+            if pad_to_max_side:
+                res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+                offset_x = (max_side - w_resize_new) // 2
+                offset_y = (max_side - h_resize_new) // 2
+                res[
+                    offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
+                ] = np.array(input_image)
+                input_image = Image.fromarray(res)
+            return input_image
+
+        # the style reference is fetched from its URL; generate() expects a PIL image
+        style_image = load_image(
+            "https://huggingface.co/spaces/radames/InstantStyle-SDXL-Lightning/resolve/main/assets/0.jpg"
+        )
+        source_image = None
+        prompt = "a cat, masterpiece, best quality, high quality"
+        scale = 1.0
+        control_scale = 0.0
+
+        return create_image(
+            image_pil=style_image,
+            input_image=source_image,
+            prompt=prompt,
+            n_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
+            scale=scale,
+            control_scale=control_scale,
+            guidance_scale=0.0,
+            num_inference_steps=2,
+            seed=42,
+            target="Load only style blocks",
+            neg_content_prompt="",
+            neg_content_scale=0,
+        )
requirements.txt CHANGED
@@ -1,2 +1,16 @@
-transformers==4.18.0
-holidays==0.13
+diffusers==0.27.2
+torch>=2.0.0
+torchvision>=0.15.1
+transformers>=4.37.1
+accelerate
+safetensors
+einops
+spaces>=0.19.4
+omegaconf
+peft
+huggingface-hub>=0.20.2
+opencv-python
+gradio
+controlnet_aux
+gdown
+peft
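
For a quick local sanity check of the new handler, a minimal sketch follows; it is not part of the commit. The module name `handler`, the empty payload, and the output filename are assumptions, and it presumes the InstantStyle `ip_adapter` package is importable and a CUDA GPU is available (the UNet weights load as fp16).

# hypothetical smoke test -- not part of this commit
from handler import EndpointHandler  # assumes handler.py is on the Python path

handler = EndpointHandler(model_dir=".")  # model_dir is accepted but unused by __init__
result = handler({"inputs": ""})          # __call__ ignores `data`; all parameters are hardcoded
result.save("styled_cat.png")             # create_image() returns a single PIL.Image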