Commit 3308ae3 · yibolu committed
Parent(s): 6eca12e

update ipadapter
lyrasd_model/module/lyrasd_ip_adapter.py
CHANGED
@@ -45,17 +45,11 @@ class LyraIPAdapter:
         image_encoder_path=None,
         num_ip_tokens=4,
         ip_projection_dim=None,
-        fp_ckpt=None,
-        num_fp_tokens=1,
-        fp_projection_dim=None,
     ):
         self.pipe = sd_pipe
         self.device = device
-        self.fp_ckpt = fp_ckpt
         self.ip_ckpt = ip_ckpt
-        self.num_fp_tokens = num_fp_tokens
         self.num_ip_tokens = num_ip_tokens
-        self.fp_projection_dim = fp_projection_dim
         self.ip_projection_dim = ip_projection_dim
         self.sdxl = sdxl
         self.ip_plus = ip_plus
@@ -76,10 +70,6 @@ class LyraIPAdapter:
         else:
             self.image_proj_model = self.init_proj(self.ip_projection_dim, self.num_ip_tokens)
 
-        # face proj model
-        if self.fp_ckpt:
-            self.face_proj_model = self.init_proj(self.fp_projection_dim, self.num_fp_tokens)
-
         self.load_ip_adapter()
 
     def init_proj_diffuser(self, state_dict):
@@ -131,16 +121,9 @@ class LyraIPAdapter:
         pretrained_path, subfolder, weight_name = parse_ckpt_path(self.ip_ckpt)
         dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
         unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
-
-        if self.fp_ckpt:
-            state_dict = torch.load(self.fp_ckpt, map_location="cpu")
-            self.face_proj_model.load_state_dict(state_dict["face_proj"])
-            pretrained_path, subfolder, weight_name = parse_ckpt_path(self.fp_ckpt)
-            dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
-            unet.load_facein(dir_ipadapter, "fp16")
 
     @torch.inference_mode()
-    def get_image_embeds(self, image=None
+    def get_image_embeds(self, image=None):
         image_prompt_embeds, uncond_image_prompt_embeds = None, None
 
         if image is not None:
@@ -160,22 +143,11 @@ class LyraIPAdapter:
             uncond_clip_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
             image_prompt_embeds = clip_image_prompt_embeds
             uncond_image_prompt_embeds = uncond_clip_image_prompt_embeds
-
-        if face_emb is not None:
-            face_embeds = face_emb.to(self.device, dtype=torch.float16)
-            face_prompt_embeds = self.face_proj_model(face_embeds)
-            uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
-            if image_prompt_embeds is None:
-                image_prompt_embeds = face_prompt_embeds
-                uncond_image_prompt_embeds = uncond_face_prompt_embeds
-            else:
-                image_prompt_embeds = torch.cat([face_prompt_embeds, image_prompt_embeds], axis=1)
-                uncond_image_prompt_embeds = torch.cat([uncond_face_prompt_embeds, uncond_image_prompt_embeds], dim=1)
 
         return image_prompt_embeds, uncond_image_prompt_embeds
 
     @torch.inference_mode()
-    def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None,
+    def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None, batch_size = 1, ip_scale=1.0, do_classifier_free_guidance=True):
         dict_tensor = {}
 
         if self.ip_ckpt and ip_scale>0:
@@ -199,91 +171,4 @@ class LyraIPAdapter:
             clip_image_embeds = torch.cat([uncond_clip_image_embeds, clip_image_embeds])
             ip_image_embeds = self.image_proj_model(clip_image_embeds)
             dict_tensor["ip_hidden_states"] = ip_image_embeds
-
-        if face_emb is not None and self.fp_ckpt and ip_scale>0:
-            face_embeds = face_emb.to(self.device, dtype=torch.float16)
-            face_prompt_embeds = self.face_proj_model(face_embeds)
-            uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
-            if do_classifier_free_guidance:
-                fp_image_embeds = torch.cat([uncond_face_prompt_embeds, face_prompt_embeds])
-            else:
-                fp_image_embeds = face_prompt_embeds
-            dict_tensor["fp_hidden_states"] = fp_image_embeds
         return dict_tensor
-
-
-if __name__ == "__main__":
-    sys.path.append("/data/home/kiokaxiao/repos/LyraSD/python/lyrasd")
-    from lyrasd_model import LyraSdXLTxt2ImgPipeline
-
-    model_path = "/data/SharedModels/SD/checkpoints/stable-diffusion-xl-base-1.0/"
-    # model_path = "/cfs-datasets/projects/VirtualIdol/models/base_model/sdxl/xxmix9realisticsdxlV1"
-    lib_path = os.environ.get("LIBLYRASD_SO")
-
-    dir_ip_adapter = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"
-    dir_facein = "/cfs-datasets/projects/VirtualIdol/models/FaceIn/v1/FaceIn_sdxl.bin"
-    image_encoder_path = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/models/image_encoder"
-
-    pipeline = LyraSdXLTxt2ImgPipeline(model_path, lib_path)
-    pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16,1024, dir_facein, 1, 512)
-    # pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16,1024, "", 1, 512)
-
-    face_emb = np.load("/data/home/kiokaxiao/repos/VidolImageDraw/girl.npy")
-    face_emb = torch.Tensor(face_emb.reshape([1,-1]))
-    ip_image = Image.open("/data/home/kiokaxiao/repos/VidolImageDraw/images/input_image.png").convert('RGB')
-
-    generator = torch.Generator("cuda").manual_seed(123)
-    batches = [2]
-    sizes = [[512, 512], [768, 768], [1024, 1024]]
-    # sizes = [[832, 640]]
-    # sizes = [[1024, 1024]]
-    running_cnt = 1
-    do_bench = False
-
-    ip_ratio = 1
-    facein_ratio = 0.6
-    extra_tensor_dict = {}
-    extra_tensor_dict = pipeline.ip_adapter_helper.get_image_embeds_lyrasd(ip_image, None, face_emb, batches[0], ip_ratio, facein_ratio)
-    param_scale_dict = {"facein_ratio": facein_ratio, "ip_ratio": ip_ratio}
-    draw_cfg = {'width': 640,
-                'num_inference_steps': 30,
-                'height': 832,
-                'negative_prompt': '(worst quality, low quality, 3d, 2d, cartoons, sketch), tooth, open mouth',
-                'guidance_scale': 7,
-                'prompt': 'xxmixgirl, masterpiece, best quality, 1girl, solo, looking at viewer, simple background, hair ornament, black eyes, portrait',
-                'output_type': 'pil',
-                'extra_tensor_dict': extra_tensor_dict,
-                "param_scale_dict": param_scale_dict}
-
-
-    def warmup(draw_cfg):
-        draw_cfg_wm = deepcopy(draw_cfg)
-        draw_cfg_wm['num_inference_steps'] = 1
-        pipeline(**draw_cfg_wm, generator= generator)
-
-    if not do_bench:
-        images = pipeline(**draw_cfg, generator= generator)
-    else:
-        for batch in batches:
-            for height, width in sizes:
-                draw_cfg['width'] = width
-                draw_cfg['height'] = height
-                draw_cfg['num_images_per_prompt'] = batch
-                draw_cfg["num_inference_steps"] = 20
-                warmup(draw_cfg)
-                time_uses = []
-                for x in range(running_cnt):
-                    start = time.perf_counter()
-                    draw_cfg['num_images_per_prompt'] = batch
-                    generator = torch.Generator("cuda").manual_seed(123)
-                    print("draw_cfg: ", draw_cfg.keys())
-                    print("draw_cfg: ", draw_cfg)
-
-                    images = pipeline(**draw_cfg, generator= generator)
-                    time_use = time.perf_counter() - start
-                    time_uses.append(time_use)
-                print("bench", batch, width, sum(time_uses)/running_cnt, get_mem_use())
-
-    print(type(images))
-    images[0].save("t.png")
-
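For reference, a minimal usage sketch of the IP-Adapter-only path that remains after this commit. The positional call to get_image_embeds_lyrasd follows the new signature shown in this diff; everything else (the paths, the load_ip_adapter argument list, and the pipeline keywords) is carried over from the removed __main__ example above and may not match the updated API exactly.

# Sketch only: paths are placeholders, and the load_ip_adapter argument list is
# assumed from the removed example with the FaceIn arguments dropped.
import torch
from PIL import Image
from lyrasd_model import LyraSdXLTxt2ImgPipeline

pipeline = LyraSdXLTxt2ImgPipeline("/path/to/stable-diffusion-xl-base-1.0", "/path/to/liblyrasd.so")
pipeline.load_ip_adapter("/path/to/ip-adapter-plus_sdxl_vit-h.bin", True,
                         "/path/to/image_encoder", 16, 1024)

ip_image = Image.open("input_image.png").convert("RGB")
ip_ratio = 1.0

# New signature from this diff: (image, ip_image_embeds, batch_size, ip_scale,
# do_classifier_free_guidance); the face-embedding argument is gone.
extra_tensor_dict = pipeline.ip_adapter_helper.get_image_embeds_lyrasd(
    ip_image, None, 1, ip_ratio, True)

images = pipeline(
    prompt="masterpiece, best quality, 1girl, portrait",
    negative_prompt="(worst quality, low quality)",
    height=832, width=640, num_inference_steps=30, guidance_scale=7,
    output_type="pil",
    extra_tensor_dict=extra_tensor_dict,
    param_scale_dict={"ip_ratio": ip_ratio},  # "facein_ratio" assumed no longer used
    generator=torch.Generator("cuda").manual_seed(123),
)
images[0].save("t.png")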