InstantID-XS / utils /tools.py
XuDongZhou's picture
Upload 9 files
96bea52 verified
raw
history blame
4.54 kB
import os, json
import cv2
import glob
import numpy as np
from PIL import Image
import torch
def im_resize(original_image, short_len=1024):
    """Resize an image so its shorter side equals ``short_len``, keeping aspect ratio.

    Args:
        original_image: Image array accepted by ``cv2.resize``. The first two
            axes are treated as (height, width), so both 2-D (grayscale) and
            3-D (H, W, C) arrays work.
        short_len: Target length, in pixels, of the shorter side. The longer
            side is scaled proportionally and truncated with ``int()``.

    Returns:
        The resized image. If the shorter side already equals ``short_len``,
        the image is still passed through ``cv2.resize`` at its current size.
    """
    # shape[:2] instead of shape[:-1]: the latter only yields (h, w) for 3-D
    # arrays and breaks this unpacking for 2-D grayscale images.
    h, w = original_image.shape[:2]
    if min(h, w) != short_len:
        # cv2.resize expects dsize as (width, height).
        if h > w:
            out_size = (short_len, int(h / w * short_len))
        else:
            out_size = (int(w / h * short_len), short_len)
    else:
        out_size = (w, h)
    return cv2.resize(original_image, out_size)
def pixelize(image, block_size=64):
    """Give an image a blocky look by flattening ``block_size`` x ``block_size`` tiles.

    Steps:
    1. Resize the image so both sides are whole multiples of ``block_size``.
    2. Replace every tile with its mean value. The mean is computed with
       ``dtype=int``, so it is truncated to an integer.
    3. Resize the result back to the input's width and height.

    Args:
        image: Image array. The first two axes are treated as (height, width).
            Both 2-D (grayscale) and 3-D (H, W, C) arrays are supported.
        block_size: Edge length of each square tile, in pixels of the
            intermediate resized image.

    Returns:
        The pixelized image with the same height and width as ``image``.
    """
    height, width = image.shape[:2]
    # Round each side down to a multiple of block_size, but keep at least one
    # tile. A side shorter than block_size would otherwise become 0, and
    # cv2.resize rejects a zero dsize.
    new_width = max(width // block_size, 1) * block_size
    new_height = max(height // block_size, 1) * block_size
    resized_image = cv2.resize(image, (new_width, new_height))

    rows = new_height // block_size
    cols = new_width // block_size
    # View the image as (tile_row, y_in_tile, tile_col, x_in_tile, channel).
    # The trailing -1 also covers 2-D arrays (cv2.resize returns 2-D output
    # for single-channel input).
    tiles = resized_image.reshape(rows, block_size, cols, block_size, -1)
    # dtype=int matches the truncating integer mean of the per-tile approach.
    tile_means = tiles.mean(axis=(1, 3), dtype=int)
    # Spread each tile's mean back over all of the tile's pixels.
    flattened = np.broadcast_to(tile_means[:, None, :, None, :], tiles.shape)
    resized_image = flattened.reshape(resized_image.shape).astype(resized_image.dtype)

    # Resize back to the caller's original dimensions.
    return cv2.resize(resized_image, (width, height))
def get_kps_bbox_faceid(w, h, json_path):
    """Load face keypoints, bounding box and identity embedding from a JSON file.

    The stored ``kps``/``bbox`` coordinates are mapped from a 512 x 512 square
    frame back into a ``w`` x ``h`` image. This undoes the following forward
    transform:

    - scale by ``512 / max(w, h)``;
    - offset the shorter axis by half of the size difference.

    NOTE(review): this presumably matches how the coordinates were produced
    (a centred, padded square resize); confirm against the producer of the JSON.

    Args:
        w: Width of the target image.
        h: Height of the target image.
        json_path: Path to a JSON file with ``"kps"``, ``"bbox"`` and
            ``"embedding"`` entries.

    Returns:
        A tuple ``(kps, bbox, face_id_embed)``:

        - ``kps``: float64 keypoints, one (x, y) per row, in target-image
          coordinates.
        - ``bbox``: float64 box (x1, y1, x2, y2) in target-image coordinates.
        - ``face_id_embed``: the L2-normalised embedding as a ``torch`` tensor.
    """
    def get_new_kps_and_bbox(w, h, kps, bbox):
        # `scale` is the forward scale factor.
        # `pad` is the offset on the shorter axis, measured in the 512-pixel frame.
        scale = 512 / max(w, h)
        pad = abs(w - h) * scale / 2
        if w < h:
            # x is the shorter axis: remove the x offset.
            kps[:, 0] -= pad
            bbox[0] -= pad
            bbox[2] -= pad
        elif h < w:
            # y is the shorter axis: remove the y offset.
            kps[:, 1] -= pad
            bbox[1] -= pad
            bbox[3] -= pad
        kps /= scale
        bbox /= scale
        return kps, bbox

    with open(json_path, 'r') as file:
        data = json.load(file)
    # Force float64. Integer coordinates in the JSON would otherwise produce
    # int arrays, and the in-place `-= pad` / `/= scale` would raise a numpy
    # casting error.
    kps = np.array(data.get("kps"), dtype=np.float64)
    bbox = np.array(data.get("bbox"), dtype=np.float64)
    kps, bbox = get_new_kps_and_bbox(w, h, kps, bbox)
    embedding = data.get("embedding")
    face_id_embed = embedding / np.linalg.norm(embedding)
    face_id_embed = torch.from_numpy(face_id_embed)
    return kps, bbox, face_id_embed
def get_kps_and_face_id_embed(w, h, json_path):
    """Load face keypoints and the identity embedding from a JSON file.

    The stored ``kps`` coordinates are mapped from a 512 x 512 square frame
    back into a ``w`` x ``h`` image. This undoes the following forward
    transform:

    - scale by ``512 / max(w, h)``;
    - offset the shorter axis by half of the size difference.

    NOTE(review): this presumably matches how the keypoints were produced;
    confirm against the producer of the JSON.

    Args:
        w: Width of the target image.
        h: Height of the target image.
        json_path: Path to a JSON file with ``"kps"`` and ``"embedding"`` entries.

    Returns:
        A tuple ``(kps, face_id_embed)``:

        - ``kps``: float64 keypoints, one (x, y) per row, in target-image
          coordinates.
        - ``face_id_embed``: the L2-normalised embedding as a ``torch`` tensor.
    """
    def get_new_kps(w, h, kps):
        # `scale` is the forward scale factor.
        # `pad` is the offset on the shorter axis, measured in the 512-pixel frame.
        scale = 512 / max(w, h)
        pad = abs(w - h) * scale / 2
        if w < h:
            # x is the shorter axis.
            kps[:, 0] -= pad
        elif h < w:
            # y is the shorter axis.
            kps[:, 1] -= pad
        kps = kps / scale
        return kps

    with open(json_path, 'r') as file:
        data = json.load(file)
    # Force float64. Integer keypoints in the JSON would otherwise produce an
    # int array, and the in-place `-= pad` would raise a numpy casting error.
    kps = np.array(data.get("kps"), dtype=np.float64)
    kps = get_new_kps(w, h, kps)
    embedding = data.get("embedding")
    face_id_embed = embedding / np.linalg.norm(embedding)
    face_id_embed = torch.from_numpy(face_id_embed)
    return kps, face_id_embed
def get_face_id_embed(json_path):
    """Read the ``"embedding"`` list from a JSON file and scale it to unit L2 norm.

    Args:
        json_path: Path to a JSON file containing an ``"embedding"`` entry.

    Returns:
        The normalised embedding as a ``torch`` tensor created with
        ``torch.from_numpy``.
    """
    with open(json_path, 'r') as fh:
        raw_embedding = json.load(fh).get("embedding")
    unit_embedding = raw_embedding / np.linalg.norm(raw_embedding)
    return torch.from_numpy(unit_embedding)
def get_module_kohya_state_dict(
    module,
    prefix: str,
    dtype: torch.dtype,
    adapter_name: str = "default",
    alpha: float = 8,
):
    """Convert a PEFT LoRA state dict into kohya-ss key naming.

    Key rewriting:

    - ``"unet.base_model.model"`` is replaced by ``prefix``;
    - ``lora_A`` / ``lora_B`` become ``lora_down`` / ``lora_up``;
    - every ``.`` except the last two becomes ``_``.

    Args:
        module: Mapping of PEFT parameter names to tensors.
        prefix: Replacement for the ``"unet.base_model.model"`` key prefix.
        dtype: dtype that every converted tensor (and each alpha) is cast to.
        adapter_name: Accepted for API compatibility; not used.
        alpha: Value stored in the ``<module>.alpha`` entry emitted for every
            ``lora_down`` weight. Defaults to 8, the previously hard-coded value.

    Returns:
        A new dict keyed in kohya-ss style.
    """
    kohya_ss_state_dict = {}
    for peft_key, weight in module.items():
        kohya_key = peft_key.replace("unet.base_model.model", prefix)
        kohya_key = kohya_key.replace("lora_A", "lora_down")
        kohya_key = kohya_key.replace("lora_B", "lora_up")
        # Keep the last two dots (e.g. "<module>.lora_down.weight"). The count
        # is clamped at 0 because str.replace treats a negative count as
        # "replace all", which would flatten keys with fewer than two dots.
        n_flatten = max(kohya_key.count(".") - 2, 0)
        kohya_key = kohya_key.replace(".", "_", n_flatten)
        kohya_ss_state_dict[kohya_key] = weight.to(dtype)
        # Emit one alpha entry per LoRA module, keyed by the flattened module name.
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(alpha).to(dtype)
    return kohya_ss_state_dict
def get_module_kohya_state_dict_xs(module, dtype):
    """Rename attention keys in a state dict and optionally cast the tensors.

    Renaming rules:

    - keys containing ``"mid_block"``: ``attentions`` becomes
      ``base_midblock.attentions``;
    - otherwise, keys containing ``"down_block"``: ``attentions`` becomes
      ``base_attentions``;
    - all other keys are kept unchanged.

    NOTE(review): presumably this maps base-UNet parameter names onto the
    InstantID-XS module layout; confirm against the model definition.

    Args:
        module: Mapping of parameter names to tensors.
        dtype: dtype to cast every tensor to, or ``None`` to keep the tensors
            unchanged (the same objects are stored).

    Returns:
        A new dict with the renamed keys.
    """
    kohya_ss_state_dict = {}
    for peft_key, weight in module.items():
        if "mid_block" in peft_key:
            peft_key = peft_key.replace('attentions', 'base_midblock.attentions')
        elif "down_block" in peft_key:
            peft_key = peft_key.replace('attentions', 'base_attentions')
        # Identity check against the None sentinel rather than `== None`.
        kohya_ss_state_dict[peft_key] = weight if dtype is None else weight.to(dtype)
    return kohya_ss_state_dict