import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from external.llite.library import model_util
from external.llite.library import sdxl_original_unet

VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"

# Reference model used to read the Diffusers configuration (scheduler, tokenizers, VAE) when saving
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"

DIFFUSERS_SDXL_UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_time",
    "addition_embed_type_num_heads": 64,
    "addition_time_embed_dim": 256,
    "attention_head_dim": [5, 10, 20],
    "block_out_channels": [320, 640, 1280],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 2048,
    "cross_attention_norm": None,
    "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": None,
    "encoder_hid_dim_type": None,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_attention_heads": None,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 4,
    "projection_class_embeddings_input_dim": 2816,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "default",
    "sample_size": 128,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "transformer_layers_per_block": [1, 2, 10],
    "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
    "upcast_attention": False,
    "use_linear_projection": True,
}


def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
    """Convert SDXL text encoder 2 weights to the HuggingFace ``CLIPTextModelWithProjection`` key layout.

    The conversion is basically the same as for SD2. ``logit_scale`` has no counterpart in Diffusers, so it is
    returned separately; it is used again when a checkpoint is saved.

    Args:
        checkpoint: state dict whose keys start with ``conditioner.embedders.1.model.``.
        max_length: number of token positions used to build ``text_model.embeddings.position_ids``.

    Returns:
        ``(new_sd, logit_scale)``; ``logit_scale`` is ``None`` when the checkpoint does not contain it.

    Raises:
        ValueError: if a ``resblocks`` key refers to an unexpected sub-module.
    """
    SDXL_KEY_PREFIX = "conditioner.embedders.1.model."

    def convert_key(key):
        # common conversion
        key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
        key = key.replace(SDXL_KEY_PREFIX, "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif ".attn.out_proj" in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif ".attn.in_proj" in key:
                key = None  # fused q/k/v projection; split separately below
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif ".positional_embedding" in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif ".text_projection" in key:
            key = key.replace("text_model.text_projection", "text_projection.weight")
        elif ".logit_scale" in key:
            key = None  # returned separately below
        elif ".token_embedding" in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif ".ln_final" in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
        elif ".embeddings.position_ids" in key:
            key = None  # remove this key: make position_ids by ourselves
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # Split each fused attention in_proj tensor into three (q, k, v) along dim 0
    for key in keys:
        if ".resblocks" in key and ".attn.in_proj_" in key:
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # The original SD checkpoint has no position_ids, so add them (shape (1, max_length), int64)
    new_sd["text_model.embeddings.position_ids"] = torch.arange(max_length, dtype=torch.int64).unsqueeze(0)

    # logit_scale is not part of the Diffusers model, but we want to restore it on save, so return it separately
    logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)

    # temporary workaround for text_projection.weight.weight for Playground-v2
    if "text_projection.weight.weight" in new_sd:
        print("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
        new_sd["text_projection.weight"] = new_sd.pop("text_projection.weight.weight")

    return new_sd, logit_scale


# load state_dict without allocating new tensors
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
    """Move every tensor of ``state_dict`` into ``model`` with accelerate's ``set_module_tensor_to_device``.

    Keys must match the model's state dict exactly, similar to ``model.load_state_dict()``. On success each
    entry is popped from ``state_dict`` as it is loaded, so the dict ends up empty, and ``""`` is returned.
    ``dtype`` is forwarded to ``set_module_tensor_to_device``; when it is ``None`` the original code notes
    that fp32 is used by default.

    Raises:
        RuntimeError: listing missing and/or unexpected keys; ``state_dict`` is left untouched in that case.
    """
    # Build the model state dict once; only its keys are needed
    model_keys = model.state_dict().keys()
    # Sorted so error messages are deterministic
    missing_keys = sorted(model_keys - state_dict.keys())
    unexpected_keys = sorted(state_dict.keys() - model_keys)

    if not missing_keys and not unexpected_keys:
        for k in list(state_dict.keys()):
            set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
        return ""  # error_msgs

    error_msgs: List[str] = []
    if missing_keys:
        error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
    if unexpected_keys:
        error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))

    raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))


def _sync_position_ids(model, state_dict, max_length):
    """Make ``text_model.embeddings.position_ids`` in ``state_dict`` agree with ``model``'s state-dict keys.

    If the model's state dict contains ``position_ids`` but ``state_dict`` does not (some checkpoints omit
    it), an ``arange(max_length)`` row is added. If the model's state dict does NOT contain it (newer
    transformers releases register it as a non-persistent buffer), the entry is removed so the strict key
    check in ``_load_state_dict_on_device`` does not fail; the model keeps its own buffer in that case.
    """
    key = "text_model.embeddings.position_ids"
    if key in model.state_dict():
        if key not in state_dict:
            state_dict[key] = torch.arange(max_length).unsqueeze(0)
    else:
        state_dict.pop(key, None)


def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
    """Load text encoders, VAE and U-Net from a single-file SDXL checkpoint (.safetensors or torch pickle).

    Args:
        model_version: reserved for future use.
        ckpt_path: path to the checkpoint file.
        map_location: device the tensors are loaded onto.
        dtype: dtype for U-Net and VAE (full_fp16/bf16 integration). Text encoders remain fp32, because they
            run on CPU when caching.

    Returns:
        ``(text_model1, text_model2, vae, unet, logit_scale, ckpt_info)`` where ``ckpt_info`` is
        ``(epoch, global_step)`` for torch checkpoints and ``None`` for safetensors files.

    Raises:
        RuntimeError: if a sub-model's weights do not match its expected keys.
    """
    # Load the state dict
    if model_util.is_safetensors(ckpt_path):
        try:
            state_dict = load_file(ckpt_path, device=map_location)
        except Exception:
            state_dict = load_file(ckpt_path)  # prevent device invalid Error
        epoch = None
        global_step = None
    else:
        # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint.get("epoch", 0)
            global_step = checkpoint.get("global_step", 0)
        else:
            state_dict = checkpoint
            epoch = 0
            global_step = 0
        del checkpoint  # keep only the state_dict reference

    # U-Net
    print("building U-Net")
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    print("loading U-Net from checkpoint")
    unet_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("model.diffusion_model."):
            # pop so the weights are not held twice
            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
    info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
    print("U-Net: ", info)

    # Text Encoders
    print("building text encoders")

    # Text Encoder 1 is same to Stability AI's SDXL
    text_model1_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
    )
    with init_empty_weights():
        text_model1 = CLIPTextModel._from_config(text_model1_cfg)

    # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
    # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
    )
    with init_empty_weights():
        text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

    print("loading text encoders from checkpoint")
    te1_sd = {}
    te2_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("conditioner.embedders.0.transformer."):
            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
        elif k.startswith("conditioner.embedders.1.model."):
            te2_sd[k] = state_dict.pop(k)

    # add position_ids for models that lack them, or drop them if this transformers version does not persist them
    _sync_position_ids(text_model1, te1_sd, max_length=77)
    info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
    print("text encoder 1:", info1)

    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
    _sync_position_ids(text_model2, converted_sd, max_length=77)
    info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)  # remain fp32
    print("text encoder 2:", info2)

    # prepare vae
    print("building VAE")
    vae_config = model_util.create_vae_diffusers_config()
    with init_empty_weights():
        vae = AutoencoderKL(**vae_config)

    print("loading VAE from checkpoint")
    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
    info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
    print("VAE:", info)

    ckpt_info = (epoch, global_step) if epoch is not None else None
    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info


def make_unet_conversion_map():
    """Return a list of ``(sdxl_prefix, diffusers_prefix)`` pairs for converting U-Net keys.

    The layout is adapted from the SD1/SD2 map and generated for three down/up blocks. A few generated prefixes
    may not correspond to real SDXL weights (e.g. attentions of ``down_blocks.0``, which is a plain
    ``DownBlock2D`` per ``DIFFUSERS_SDXL_UNET_CONFIG``); such entries are harmless because lookups are by
    prefix and those keys never occur.
    """
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            # (the SD1/2 version skipped down_blocks.3 here; SDXL has only 3 blocks, so always emitted)
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # SDXL: up_blocks.0 has attention layers too, so no "i > 0" guard here
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        # (the SD1/2 version skipped down_blocks.3 here; SDXL has only 3 blocks, so always emitted)
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.2."  # changed for sdxl
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            # resnet prefixes are expanded into their individual sub-layers
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
    """Convert a Diffusers U-Net state dict to SDXL (original) key names."""
    unet_conversion_map = make_unet_conversion_map()

    conversion_map = {hf_prefix: sd_prefix for sd_prefix, hf_prefix in unet_conversion_map}
    return convert_unet_state_dict(du_sd, conversion_map)


def convert_unet_state_dict(src_sd, conversion_map):
    """Rename every key of ``src_sd`` using the longest matching prefix in ``conversion_map``.

    Raises:
        AssertionError: if no prefix of a key is found in ``conversion_map``.
    """
    converted_sd = {}
    for src_key, value in src_sd.items():
        # Scanning the whole map for every key would be slow, so strip fragments from the right and
        # look up each shorter prefix directly.
        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias
        while len(src_key_fragments) > 0:
            src_key_prefix = ".".join(src_key_fragments) + "."
            if src_key_prefix in conversion_map:
                converted_prefix = conversion_map[src_key_prefix]
                converted_key = converted_prefix + src_key[len(src_key_prefix) :]
                converted_sd[converted_key] = value
                break
            src_key_fragments.pop(-1)
        assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"

    return converted_sd


def convert_sdxl_unet_state_dict_to_diffusers(sd):
    """Convert an SDXL (original) U-Net state dict to Diffusers key names."""
    unet_conversion_map = make_unet_conversion_map()

    conversion_dict = {sd_prefix: hf_prefix for sd_prefix, hf_prefix in unet_conversion_map}
    return convert_unet_state_dict(sd, conversion_dict)


def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
    """Convert a HuggingFace text encoder 2 state dict back to SDXL key names (without the embedder prefix).

    Separate q/k/v projections are concatenated into a fused ``attn.in_proj_*`` tensor and ``logit_scale``
    is re-added when given. ``position_ids`` are dropped.

    Raises:
        ValueError: if a ``layers`` key refers to an unexpected sub-module.
    """

    def convert_key(key):
        # drop position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif ".self_attn.out_proj" in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif ".self_attn." in key:
                key = None  # q/k/v projections; fused separately below
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif ".position_embedding" in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif ".token_embedding" in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif "text_projection" in key:  # no dot in key
            key = key.replace("text_projection.weight", "text_projection")
        elif "final_layer_norm" in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # Concatenate q/k/v back into a single fused in_proj tensor
    for key in keys:
        if "layers" in key and "q_proj" in key:
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value = torch.cat([checkpoint[key_q], checkpoint[key_k], checkpoint[key_v]])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    if logit_scale is not None:
        new_sd["logit_scale"] = logit_scale

    return new_sd


def save_stable_diffusion_checkpoint(
    output_file,
    text_encoder1,
    text_encoder2,
    unet,
    epochs,
    steps,
    ckpt_info,
    vae,
    logit_scale,
    metadata,
    save_dtype=None,
):
    """Save the models as a single-file SDXL checkpoint (.safetensors or torch pickle).

    ``ckpt_info`` (``(epoch, global_step)`` from loading) is added to ``epochs``/``steps`` when given; these
    values are only stored in the torch format. ``metadata`` is only written for safetensors output.

    Returns:
        The number of tensors saved.
    """
    state_dict = {}

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    update_sd("model.diffusion_model.", unet.state_dict())

    # Convert the text encoders
    update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())

    text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
    update_sd("conditioner.embedders.1.model.", text_enc2_dict)

    # Convert the VAE
    vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
    update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict)
    new_ckpt = {"state_dict": state_dict}

    # epoch and global_step are sometimes not int
    if ckpt_info is not None:
        epochs += ckpt_info[0]
        steps += ckpt_info[1]

    new_ckpt["epoch"] = epochs
    new_ckpt["global_step"] = steps

    if model_util.is_safetensors(output_file):
        save_file(state_dict, output_file, metadata)
    else:
        torch.save(new_ckpt, output_file)

    return key_count


def save_diffusers_checkpoint(
    output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
    """Save the models as a Diffusers ``StableDiffusionXLPipeline`` directory.

    The scheduler and tokenizers (and the VAE when ``vae`` is None) are loaded from
    ``pretrained_model_name_or_path``, falling back to ``DIFFUSERS_REF_MODEL_ID_SDXL``.
    """
    from diffusers import StableDiffusionXLPipeline

    # convert U-Net
    unet_sd = unet.state_dict()
    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)

    diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
    if save_dtype is not None:
        diffusers_unet.to(save_dtype)
    diffusers_unet.load_state_dict(du_unet_sd)

    # create pipeline to save
    if pretrained_model_name_or_path is None:
        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL

    scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    # prevent local path from being saved
    def remove_name_or_path(model):
        if hasattr(model, "config"):
            model.config._name_or_path = None

    remove_name_or_path(diffusers_unet)
    remove_name_or_path(text_encoder1)
    remove_name_or_path(text_encoder2)
    remove_name_or_path(scheduler)
    remove_name_or_path(tokenizer1)
    remove_name_or_path(tokenizer2)
    remove_name_or_path(vae)

    pipeline = StableDiffusionXLPipeline(
        unet=diffusers_unet,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer1,
        tokenizer_2=tokenizer2,
    )
    if save_dtype is not None:
        pipeline.to(None, save_dtype)
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)