Spaces:
Paused
Paused
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
State dict utilities: utility methods for converting state dicts easily | |
""" | |
import enum | |
from .logging import get_logger | |
logger = get_logger(__name__) | |
class StateDictType(enum.Enum): | |
""" | |
The mode to use when converting state dicts. | |
""" | |
DIFFUSERS_OLD = "diffusers_old" | |
KOHYA_SS = "kohya_ss" | |
PEFT = "peft" | |
DIFFUSERS = "diffusers" | |
# We need to define a proper mapping for Unet since it uses different output keys than text encoder | |
# e.g. to_q_lora -> q_proj / to_q | |
UNET_TO_DIFFUSERS = { | |
".to_out_lora.up": ".to_out.0.lora_B", | |
".to_out_lora.down": ".to_out.0.lora_A", | |
".to_q_lora.down": ".to_q.lora_A", | |
".to_q_lora.up": ".to_q.lora_B", | |
".to_k_lora.down": ".to_k.lora_A", | |
".to_k_lora.up": ".to_k.lora_B", | |
".to_v_lora.down": ".to_v.lora_A", | |
".to_v_lora.up": ".to_v.lora_B", | |
".lora.up": ".lora_B", | |
".lora.down": ".lora_A", | |
} | |
DIFFUSERS_TO_PEFT = { | |
".q_proj.lora_linear_layer.up": ".q_proj.lora_B", | |
".q_proj.lora_linear_layer.down": ".q_proj.lora_A", | |
".k_proj.lora_linear_layer.up": ".k_proj.lora_B", | |
".k_proj.lora_linear_layer.down": ".k_proj.lora_A", | |
".v_proj.lora_linear_layer.up": ".v_proj.lora_B", | |
".v_proj.lora_linear_layer.down": ".v_proj.lora_A", | |
".out_proj.lora_linear_layer.up": ".out_proj.lora_B", | |
".out_proj.lora_linear_layer.down": ".out_proj.lora_A", | |
".lora_linear_layer.up": ".lora_B", | |
".lora_linear_layer.down": ".lora_A", | |
} | |
DIFFUSERS_OLD_TO_PEFT = { | |
".to_q_lora.up": ".q_proj.lora_B", | |
".to_q_lora.down": ".q_proj.lora_A", | |
".to_k_lora.up": ".k_proj.lora_B", | |
".to_k_lora.down": ".k_proj.lora_A", | |
".to_v_lora.up": ".v_proj.lora_B", | |
".to_v_lora.down": ".v_proj.lora_A", | |
".to_out_lora.up": ".out_proj.lora_B", | |
".to_out_lora.down": ".out_proj.lora_A", | |
".lora_linear_layer.up": ".lora_B", | |
".lora_linear_layer.down": ".lora_A", | |
} | |
PEFT_TO_DIFFUSERS = { | |
".q_proj.lora_B": ".q_proj.lora_linear_layer.up", | |
".q_proj.lora_A": ".q_proj.lora_linear_layer.down", | |
".k_proj.lora_B": ".k_proj.lora_linear_layer.up", | |
".k_proj.lora_A": ".k_proj.lora_linear_layer.down", | |
".v_proj.lora_B": ".v_proj.lora_linear_layer.up", | |
".v_proj.lora_A": ".v_proj.lora_linear_layer.down", | |
".out_proj.lora_B": ".out_proj.lora_linear_layer.up", | |
".out_proj.lora_A": ".out_proj.lora_linear_layer.down", | |
"to_k.lora_A": "to_k.lora.down", | |
"to_k.lora_B": "to_k.lora.up", | |
"to_q.lora_A": "to_q.lora.down", | |
"to_q.lora_B": "to_q.lora.up", | |
"to_v.lora_A": "to_v.lora.down", | |
"to_v.lora_B": "to_v.lora.up", | |
"to_out.0.lora_A": "to_out.0.lora.down", | |
"to_out.0.lora_B": "to_out.0.lora.up", | |
} | |
DIFFUSERS_OLD_TO_DIFFUSERS = { | |
".to_q_lora.up": ".q_proj.lora_linear_layer.up", | |
".to_q_lora.down": ".q_proj.lora_linear_layer.down", | |
".to_k_lora.up": ".k_proj.lora_linear_layer.up", | |
".to_k_lora.down": ".k_proj.lora_linear_layer.down", | |
".to_v_lora.up": ".v_proj.lora_linear_layer.up", | |
".to_v_lora.down": ".v_proj.lora_linear_layer.down", | |
".to_out_lora.up": ".out_proj.lora_linear_layer.up", | |
".to_out_lora.down": ".out_proj.lora_linear_layer.down", | |
} | |
PEFT_TO_KOHYA_SS = { | |
"lora_A": "lora_down", | |
"lora_B": "lora_up", | |
# This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys, | |
# adding prefixes and adding alpha values | |
# Check `convert_state_dict_to_kohya` for more | |
} | |
PEFT_STATE_DICT_MAPPINGS = { | |
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT, | |
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT, | |
} | |
DIFFUSERS_STATE_DICT_MAPPINGS = { | |
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS, | |
StateDictType.PEFT: PEFT_TO_DIFFUSERS, | |
} | |
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS} | |
KEYS_TO_ALWAYS_REPLACE = { | |
".processor.": ".", | |
} | |
def convert_state_dict(state_dict, mapping): | |
r""" | |
Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values. | |
Args: | |
state_dict (`dict[str, torch.Tensor]`): | |
The state dict to convert. | |
mapping (`dict[str, str]`): | |
The mapping to use for conversion, the mapping should be a dictionary with the following structure: | |
- key: the pattern to replace | |
- value: the pattern to replace with | |
Returns: | |
converted_state_dict (`dict`) | |
The converted state dict. | |
""" | |
converted_state_dict = {} | |
for k, v in state_dict.items(): | |
# First, filter out the keys that we always want to replace | |
for pattern in KEYS_TO_ALWAYS_REPLACE.keys(): | |
if pattern in k: | |
new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern] | |
k = k.replace(pattern, new_pattern) | |
for pattern in mapping.keys(): | |
if pattern in k: | |
new_pattern = mapping[pattern] | |
k = k.replace(pattern, new_pattern) | |
break | |
converted_state_dict[k] = v | |
return converted_state_dict | |
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs): | |
r""" | |
Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or | |
new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now. | |
Args: | |
state_dict (`dict[str, torch.Tensor]`): | |
The state dict to convert. | |
original_type (`StateDictType`, *optional*): | |
The original type of the state dict, if not provided, the method will try to infer it automatically. | |
""" | |
if original_type is None: | |
# Old diffusers to PEFT | |
if any("to_out_lora" in k for k in state_dict.keys()): | |
original_type = StateDictType.DIFFUSERS_OLD | |
elif any("lora_linear_layer" in k for k in state_dict.keys()): | |
original_type = StateDictType.DIFFUSERS | |
else: | |
raise ValueError("Could not automatically infer state dict type") | |
if original_type not in PEFT_STATE_DICT_MAPPINGS.keys(): | |
raise ValueError(f"Original type {original_type} is not supported") | |
mapping = PEFT_STATE_DICT_MAPPINGS[original_type] | |
return convert_state_dict(state_dict, mapping) | |
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): | |
r""" | |
Converts a state dict to new diffusers format. The state dict can be from previous diffusers format | |
(`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will | |
return the state dict as is. | |
The method only supports the conversion from diffusers old, PEFT to diffusers new for now. | |
Args: | |
state_dict (`dict[str, torch.Tensor]`): | |
The state dict to convert. | |
original_type (`StateDictType`, *optional*): | |
The original type of the state dict, if not provided, the method will try to infer it automatically. | |
kwargs (`dict`, *args*): | |
Additional arguments to pass to the method. | |
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended | |
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in | |
`get_peft_model_state_dict` method: | |
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 | |
but we add it here in case we don't want to rely on that method. | |
""" | |
peft_adapter_name = kwargs.pop("adapter_name", None) | |
if peft_adapter_name is not None: | |
peft_adapter_name = "." + peft_adapter_name | |
else: | |
peft_adapter_name = "" | |
if original_type is None: | |
# Old diffusers to PEFT | |
if any("to_out_lora" in k for k in state_dict.keys()): | |
original_type = StateDictType.DIFFUSERS_OLD | |
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): | |
original_type = StateDictType.PEFT | |
elif any("lora_linear_layer" in k for k in state_dict.keys()): | |
# nothing to do | |
return state_dict | |
else: | |
raise ValueError("Could not automatically infer state dict type") | |
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): | |
raise ValueError(f"Original type {original_type} is not supported") | |
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] | |
return convert_state_dict(state_dict, mapping) | |
def convert_unet_state_dict_to_peft(state_dict): | |
r""" | |
Converts a state dict from UNet format to diffusers format - i.e. by removing some keys | |
""" | |
mapping = UNET_TO_DIFFUSERS | |
return convert_state_dict(state_dict, mapping) | |
def convert_all_state_dict_to_peft(state_dict): | |
r""" | |
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` | |
for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft` | |
""" | |
try: | |
peft_dict = convert_state_dict_to_peft(state_dict) | |
except Exception as e: | |
if str(e) == "Could not automatically infer state dict type": | |
peft_dict = convert_unet_state_dict_to_peft(state_dict) | |
else: | |
raise | |
if not any("lora_A" in key or "lora_B" in key for key in peft_dict.keys()): | |
raise ValueError("Your LoRA was not converted to PEFT") | |
return peft_dict | |
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs): | |
r""" | |
Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc. | |
The method only supports the conversion from PEFT to Kohya for now. | |
Args: | |
state_dict (`dict[str, torch.Tensor]`): | |
The state dict to convert. | |
original_type (`StateDictType`, *optional*): | |
The original type of the state dict, if not provided, the method will try to infer it automatically. | |
kwargs (`dict`, *args*): | |
Additional arguments to pass to the method. | |
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended | |
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in | |
`get_peft_model_state_dict` method: | |
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 | |
but we add it here in case we don't want to rely on that method. | |
""" | |
try: | |
import torch | |
except ImportError: | |
logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.") | |
raise | |
peft_adapter_name = kwargs.pop("adapter_name", None) | |
if peft_adapter_name is not None: | |
peft_adapter_name = "." + peft_adapter_name | |
else: | |
peft_adapter_name = "" | |
if original_type is None: | |
if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): | |
original_type = StateDictType.PEFT | |
if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys(): | |
raise ValueError(f"Original type {original_type} is not supported") | |
# Use the convert_state_dict function with the appropriate mapping | |
kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT]) | |
kohya_ss_state_dict = {} | |
# Additional logic for replacing header, alpha parameters `.` with `_` in all keys | |
for kohya_key, weight in kohya_ss_partial_state_dict.items(): | |
if "text_encoder_2." in kohya_key: | |
kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.") | |
elif "text_encoder." in kohya_key: | |
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.") | |
elif "unet" in kohya_key: | |
kohya_key = kohya_key.replace("unet", "lora_unet") | |
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) | |
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names | |
kohya_ss_state_dict[kohya_key] = weight | |
if "lora_down" in kohya_key: | |
alpha_key = f'{kohya_key.split(".")[0]}.alpha' | |
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight)) | |
return kohya_ss_state_dict | |