# Source: MuseV repository (MuseVSpace), file musev/utils/model_util.py
# Last commit: 96d7ad8 "add musev" by anchorxia
import gc
import os
from typing import Any, Callable, List, Literal, Union, Dict, Tuple
import logging
from safetensors.torch import load_file
from safetensors import safe_open
import torch
from torch import nn
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from .convert_from_ckpt import (
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ldm_clip_checkpoint,
)
from .convert_lora_safetensor_to_diffusers import convert_motion_lora_ckpt_to_diffusers
# Module-level logger, named after this module so callers can tune its level.
logger = logging.getLogger(name=__name__)
def update_pipeline_model_parameters(
    pipeline: DiffusionPipeline,
    model_path: str = None,
    lora_dict: Dict[str, Dict] = None,
    text_model_path: str = None,
    device="cuda",
    need_unload: bool = False,
):
    """Optionally swap the base model and/or merge LoRAs into ``pipeline``.

    Args:
        pipeline (DiffusionPipeline): pipeline to update.
        model_path (str, optional): base-model checkpoint passed to
            ``update_pipeline_basemodel``; skipped when ``None``.
        lora_dict (Dict[str, Dict], optional): ``{lora_path: config}`` passed to
            ``update_pipeline_lora_models``; skipped when ``None``.
        text_model_path (str, optional): forwarded as ``text_sd_model_path``
            to ``update_pipeline_basemodel``.
        device (str, optional): target device. Defaults to "cuda".
        need_unload (bool, optional): when True, also return the list of
            unload records produced by the LoRA merge. Defaults to False.

    Returns:
        ``pipeline`` or, when ``need_unload`` is True, ``(pipeline, unload_dict)``.
        ``unload_dict`` is an empty list when no LoRA was applied.
    """
    # Initialised up front so `need_unload=True` without any LoRA returns an
    # empty (no-op) unload list instead of raising NameError.
    unload_dict = []
    if model_path is not None:
        pipeline = update_pipeline_basemodel(
            pipeline, model_path, text_sd_model_path=text_model_path, device=device
        )
    if lora_dict is not None:
        pipeline, unload_dict = update_pipeline_lora_models(
            pipeline,
            lora_dict,
            device=device,
            need_unload=need_unload,
        )
    if need_unload:
        return pipeline, unload_dict
    return pipeline
def update_pipeline_basemodel(
    pipeline: DiffusionPipeline,
    model_path: str,
    text_sd_model_path: str,
    device: str = "cuda",
):
    """Update the base-model parameters of ``pipeline`` from ``model_path``.

    Two checkpoint formats are handled:

    * ``.ckpt``: loaded with ``torch.load`` and passed directly (strict) to
      ``pipeline.unet.load_state_dict``; the VAE and text encoder are untouched.
    * ``.safetensors``: a single-file LDM-layout checkpoint. VAE and UNet
      weights are converted with ``convert_ldm_vae_checkpoint`` /
      ``convert_ldm_unet_checkpoint`` and loaded (UNet non-strict), and
      ``pipeline.text_encoder`` is replaced by the model returned from
      ``convert_ldm_clip_checkpoint``.

    Any other extension leaves the weights unchanged and logs a warning.
    In every case ``pipeline.to(device)`` is called afterwards.

    Args:
        pipeline (DiffusionPipeline): pipeline whose components are updated.
        model_path (str): path to a ``.ckpt`` or ``.safetensors`` checkpoint.
        text_sd_model_path (str): forwarded to ``convert_ldm_clip_checkpoint``
            (only used for ``.safetensors``).
        device (str, optional): device to load tensors onto and move the
            pipeline to. Defaults to "cuda".

    Returns:
        The same ``pipeline`` object, updated.

    Raises:
        ValueError: if a ``.safetensors`` file looks like a LoRA (every key
            contains "lora") rather than a full base model.
    """
    if model_path.endswith(".ckpt"):
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        pipeline.unet.load_state_dict(state_dict)
        print("update sd_model", model_path)
    elif model_path.endswith(".safetensors"):
        base_state_dict = load_file(model_path, device=device)
        # Explicit check instead of `assert` so it also fires under `python -O`.
        if all("lora" in k for k in base_state_dict):
            raise ValueError("Base model cannot be LoRA: {}".format(model_path))
        # vae
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
            base_state_dict, pipeline.vae.config
        )
        pipeline.vae.load_state_dict(converted_vae_checkpoint)
        # unet (non-strict: converted keys may not cover every UNet parameter)
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(
            base_state_dict, pipeline.unet.config
        )
        pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
        # text_model: replaced wholesale rather than loaded in place
        pipeline.text_encoder = convert_ldm_clip_checkpoint(
            base_state_dict, text_sd_model_path
        )
        print("update sd_model", model_path)
    else:
        # Previously this case was silently ignored; make the no-op visible.
        logger.warning(
            "Unsupported base model format, weights not updated: %s", model_path
        )
    pipeline.to(device)
    return pipeline
# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/cfg.yaml
LORA_BLOCK_WEIGHT_MAP = {
"FACE": [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
"DEFACE": [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
"ALL": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
"MIDD": [1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
"OUTALL": [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
}
# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/draw_pipe.py
def update_pipeline_lora_model(
pipeline: DiffusionPipeline,
lora: Union[str, Dict],
alpha: float = 0.75,
device: str = "cuda",
lora_prefix_unet: str = "lora_unet",
lora_prefix_text_encoder: str = "lora_te",
lora_unet_layers=[
"lora_unet_down_blocks_0_attentions_0",
"lora_unet_down_blocks_0_attentions_1",
"lora_unet_down_blocks_1_attentions_0",
"lora_unet_down_blocks_1_attentions_1",
"lora_unet_down_blocks_2_attentions_0",
"lora_unet_down_blocks_2_attentions_1",
"lora_unet_mid_block_attentions_0",
"lora_unet_up_blocks_1_attentions_0",
"lora_unet_up_blocks_1_attentions_1",
"lora_unet_up_blocks_1_attentions_2",
"lora_unet_up_blocks_2_attentions_0",
"lora_unet_up_blocks_2_attentions_1",
"lora_unet_up_blocks_2_attentions_2",
"lora_unet_up_blocks_3_attentions_0",
"lora_unet_up_blocks_3_attentions_1",
"lora_unet_up_blocks_3_attentions_2",
],
lora_block_weight_str: Literal["FACE", "ALL"] = "ALL",
need_unload: bool = False,
):
"""使用 lora 更新pipeline中的unet相关参数
Args:
pipeline (DiffusionPipeline): _description_
lora (Union[str, Dict]): _description_
alpha (float, optional): _description_. Defaults to 0.75.
device (str, optional): _description_. Defaults to "cuda".
lora_prefix_unet (str, optional): _description_. Defaults to "lora_unet".
lora_prefix_text_encoder (str, optional): _description_. Defaults to "lora_te".
lora_unet_layers (list, optional): _description_. Defaults to [ "lora_unet_down_blocks_0_attentions_0", "lora_unet_down_blocks_0_attentions_1", "lora_unet_down_blocks_1_attentions_0", "lora_unet_down_blocks_1_attentions_1", "lora_unet_down_blocks_2_attentions_0", "lora_unet_down_blocks_2_attentions_1", "lora_unet_mid_block_attentions_0", "lora_unet_up_blocks_1_attentions_0", "lora_unet_up_blocks_1_attentions_1", "lora_unet_up_blocks_1_attentions_2", "lora_unet_up_blocks_2_attentions_0", "lora_unet_up_blocks_2_attentions_1", "lora_unet_up_blocks_2_attentions_2", "lora_unet_up_blocks_3_attentions_0", "lora_unet_up_blocks_3_attentions_1", "lora_unet_up_blocks_3_attentions_2", ].
lora_block_weight_str (Literal["FACE", "ALL"], optional): _description_. Defaults to "ALL".
need_unload (bool, optional): _description_. Defaults to False.
Returns:
_type_: _description_
"""
# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/tool.py#L20
if lora_block_weight_str is not None:
lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()]
if lora_block_weight:
assert len(lora_block_weight) == 17
# load lora weight
if isinstance(lora, str):
state_dict = load_file(lora, device=device)
else:
for k in lora:
lora[k] = lora[k].to(device)
state_dict = lora # state_dict = {}
visited = set()
unload_dict = []
# directly update weight in diffusers model
for key in state_dict:
# it is suggested to print out the key, it usually will be something like below
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
# as we have set the alpha beforehand, so just skip
if ".alpha" in key or key in visited:
continue
if "text" in key:
layer_infos = (
key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_")
)
curr_layer = pipeline.text_encoder
else:
layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_")
curr_layer = pipeline.unet
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
alpha_key = key.replace("lora_down.weight", "alpha")
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
alpha_key = key.replace("lora_up.weight", "alpha")
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
weight_down = (
state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
)
if alpha_key in state_dict:
weight_scale = state_dict[alpha_key].item() / weight_up.shape[1]
else:
weight_scale = 1.0
# adding_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
if len(weight_up.shape) == len(weight_down.shape):
adding_weight = (
alpha
* weight_scale
* torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
)
else:
adding_weight = (
alpha
* weight_scale
* torch.einsum("a b, b c h w -> a c h w", weight_up, weight_down)
)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
if alpha_key in state_dict:
weight_scale = state_dict[alpha_key].item() / weight_up.shape[1]
else:
weight_scale = 1.0
adding_weight = alpha * weight_scale * torch.mm(weight_up, weight_down)
adding_weight = adding_weight.to(torch.float16)
if lora_block_weight:
if "text" in key:
adding_weight *= lora_block_weight[0]
else:
for idx, layer in enumerate(lora_unet_layers):
if layer in key:
adding_weight *= lora_block_weight[idx + 1]
break
curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight}
curr_layer.weight.data += adding_weight
unload_dict.append(curr_layer_unload_data)
# update visited list
for item in pair_keys:
visited.add(item)
if need_unload:
return pipeline, unload_dict
else:
return pipeline
# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/draw_pipe.py
def update_pipeline_lora_model_old(
    pipeline: DiffusionPipeline,
    lora: Union[str, Dict],
    alpha: float = 0.75,
    device: str = "cuda",
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=(
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ),
    lora_block_weight_str: Literal["FACE", "DEFACE", "ALL", "MIDD", "OUTALL"] = "ALL",
    need_unload: bool = False,
):
    """Legacy LoRA merge: update UNet / text-encoder weights of ``pipeline`` in place.

    Unlike ``update_pipeline_lora_model``, this variant ignores the per-layer
    ``.alpha`` entries (the delta is just ``alpha * up @ down``), does not
    cast the delta to another dtype, and uses ``unsqueeze`` for 4-D factors
    without an einsum fallback.

    Args:
        pipeline (DiffusionPipeline): pipeline to modify in place.
        lora (Union[str, Dict]): path to a ``.safetensors`` LoRA file, or an
            already-loaded state dict (it is copied to ``device``, not mutated).
        alpha (float, optional): global LoRA strength. Defaults to 0.75.
        device (str, optional): device the LoRA tensors are loaded onto.
            Defaults to "cuda".
        lora_prefix_unet (str, optional): key prefix of UNet entries.
            Defaults to "lora_unet".
        lora_prefix_text_encoder (str, optional): key prefix of text-encoder
            entries. Defaults to "lora_te".
        lora_unet_layers (Sequence[str], optional): UNet block name prefixes;
            the i-th one is scaled by preset entry ``i + 1``.
        lora_block_weight_str (str, optional): key of ``LORA_BLOCK_WEIGHT_MAP``
            (case-insensitive), or ``None`` for no per-block weighting.
            Defaults to "ALL".
        need_unload (bool, optional): also return the unload records.
            Defaults to False.

    Returns:
        ``pipeline`` or, when ``need_unload`` is True, ``(pipeline, unload_dict)``
        with ``{"layer", "added_weight"}`` records usable by ``unload_lora``.

    Raises:
        ValueError: if the selected preset does not have 17 entries.
        AttributeError: if a LoRA key does not resolve to a layer.
    """
    # ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/tool.py#L20
    # Initialised so lora_block_weight_str=None no longer raises UnboundLocalError.
    lora_block_weight = None
    if lora_block_weight_str is not None:
        lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()]
    if lora_block_weight and len(lora_block_weight) != 17:
        raise ValueError(
            "lora block weight must have 17 entries, got {}".format(
                len(lora_block_weight)
            )
        )
    # load lora weight (copy instead of mutating the caller's dict)
    if isinstance(lora, str):
        state_dict = load_file(lora, device=device)
    else:
        state_dict = {k: v.to(device) for k, v in lora.items()}
    visited = set()
    unload_dict = []
    # directly update weight in diffusers model
    for key in state_dict:
        # keys look like "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
        if ".alpha" in key or key in visited:
            continue
        if "text" in key:
            layer_infos = (
                key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_")
            )
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_")
            curr_layer = pipeline.unet
        # Find the target layer: module names may contain "_", so on a failed
        # lookup the next part is glued on and the lookup retried.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
            except AttributeError as err:
                if not layer_infos:
                    raise AttributeError(
                        "Cannot find layer for LoRA key {!r}".format(key)
                    ) from err
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)
                continue
            if not layer_infos:
                break
            temp_name = layer_infos.pop(0)
        if "lora_down" in key:
            pair_keys = [key.replace("lora_down", "lora_up"), key]
        else:
            pair_keys = [key, key.replace("lora_up", "lora_down")]
        # update weight (pair_keys[0] is the up factor, pair_keys[1] the down factor)
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
                state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            )
            adding_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(
                2
            ).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            adding_weight = alpha * torch.mm(weight_up, weight_down)
        if lora_block_weight:
            if "text" in key:
                adding_weight *= lora_block_weight[0]
            else:
                for idx, layer in enumerate(lora_unet_layers):
                    if layer in key:
                        adding_weight *= lora_block_weight[idx + 1]
                        break
        curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight}
        curr_layer.weight.data += adding_weight
        unload_dict.append(curr_layer_unload_data)
        # update visited list
        visited.update(pair_keys)
    if need_unload:
        return pipeline, unload_dict
    return pipeline
def update_pipeline_lora_models(
    pipeline: DiffusionPipeline,
    lora_dict: Dict[str, Dict],
    device: str = "cuda",
    need_unload: bool = True,
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
):
    """Merge several LoRA files into ``pipeline``, one after another.

    Each entry of ``lora_dict`` maps a ``.safetensors`` path to a config dict
    that may contain ``"strength"`` (default 1.0), ``"strength_offset"``
    (default 0.0, added to the strength) and ``"lora_block_weight"`` (preset
    name, default "ALL"). Every file is loaded with ``load_file`` and merged
    via ``update_pipeline_lora_model``.

    Args:
        pipeline (DiffusionPipeline): pipeline to modify in place.
        lora_dict (Dict[str, Dict]): ``{lora_path: config}`` as described above.
        device (str, optional): device for the LoRA tensors. Defaults to "cuda".
        need_unload (bool, optional): accepted for interface compatibility;
            the unload records are always returned.
        lora_prefix_unet (str, optional): key prefix of UNet entries.
        lora_prefix_text_encoder (str, optional): key prefix of text-encoder entries.
        lora_unet_layers (list, optional): UNet block name prefixes used for
            per-block weighting.

    Returns:
        ``(pipeline, unload_records)`` where ``unload_records`` concatenates
        the records of every merged LoRA (consumable by ``unload_lora``).
    """
    collected_records = []
    for lora_path, lora_cfg in lora_dict.items():
        display_name = os.path.basename(lora_path).replace(".safetensors", "")
        offset = lora_cfg.get("strength_offset", 0.0)
        effective_alpha = lora_cfg.get("strength", 1.0) + offset
        block_weight_name = lora_cfg.get("lora_block_weight", "ALL")
        lora_tensors = load_file(lora_path)
        pipeline, records = update_pipeline_lora_model(
            pipeline,
            lora=lora_tensors,
            device=device,
            alpha=effective_alpha,
            lora_prefix_unet=lora_prefix_unet,
            lora_prefix_text_encoder=lora_prefix_text_encoder,
            lora_unet_layers=lora_unet_layers,
            lora_block_weight_str=block_weight_name,
            need_unload=True,
        )
        print(
            f"Update LoRA {display_name} with alpha {effective_alpha} "
            f"and weight {block_weight_name}"
        )
        collected_records.extend(records)
    return pipeline, collected_records
def unload_lora(unload_dict: List[Dict[str, nn.Module]]):
    """Undo LoRA merges by subtracting each recorded delta from its layer.

    Args:
        unload_dict: records of the form ``{"layer": module, "added_weight": tensor}``
            as returned by the LoRA update functions; each ``added_weight`` is
            subtracted from ``layer.weight.data`` in place.

    Afterwards Python garbage collection runs and the CUDA caching allocator
    is asked to release unused cached memory.
    """
    for record in unload_dict:
        target_layer, delta = record["layer"], record["added_weight"]
        target_layer.weight.data -= delta
    gc.collect()
    torch.cuda.empty_cache()
def load_motion_lora_weights(
    animation_pipeline,
    motion_module_lora_configs=None,
):
    """Apply motion-module LoRAs to ``animation_pipeline``.

    Args:
        animation_pipeline: pipeline passed through
            ``convert_motion_lora_ckpt_to_diffusers`` once per config.
        motion_module_lora_configs (list of dict, optional): each entry needs
            ``"path"`` (checkpoint loaded with ``torch.load``) and ``"alpha"``
            (strength). ``None`` (the default) or an empty list applies nothing.

    Returns:
        The pipeline returned by the last conversion call, or the input
        pipeline unchanged when there are no configs.
    """
    # `None` default avoids a shared mutable default list.
    if motion_module_lora_configs is None:
        motion_module_lora_configs = []
    for motion_module_lora_config in motion_module_lora_configs:
        path = motion_module_lora_config["path"]
        alpha = motion_module_lora_config["alpha"]
        print(f"load motion LoRA from {path}")
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        motion_lora_state_dict = torch.load(path, map_location="cpu")
        # Some checkpoints wrap the weights under a top-level "state_dict" key.
        if "state_dict" in motion_lora_state_dict:
            motion_lora_state_dict = motion_lora_state_dict["state_dict"]
        animation_pipeline = convert_motion_lora_ckpt_to_diffusers(
            animation_pipeline, motion_lora_state_dict, alpha
        )
    return animation_pipeline