# some code is copied from:
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# Changes made to the original code:
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
import os
import random
from typing import List, Tuple, Union

import torch
from torch import nn


class DyLoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    # NOTE: support dropout in future
    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.unit = unit
        assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
        self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)

        if self.is_conv2d and self.is_conv2d_3x3:
            kernel_size = org_module.kernel_size
            self.stride = org_module.stride
            self.padding = org_module.padding
            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
        else:
            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])

        # same as microsoft's
        for lora in self.lora_A:
            torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
        for lora in self.lora_B:
            torch.nn.init.zeros_(lora)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
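
    # NOTE: each rank of lora_A/lora_B is stored as a separate Parameter (rather than as a single
    # (lora_dim, in_dim) / (out_dim, lora_dim) matrix) so that forward() can toggle requires_grad
    # per rank; concatenating the entries recovers a regular LoRA down/up weight pair.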

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        result = self.org_forward(x)

        # specify the dynamic rank
        trainable_rank = random.randint(0, self.lora_dim - 1)
        trainable_rank = trainable_rank - trainable_rank % self.unit  # make sure the rank is a multiple of unit

        # freeze some of the parameters and train only the remaining ones
        for i in range(0, trainable_rank):
            self.lora_A[i].requires_grad = False
            self.lora_B[i].requires_grad = False
        for i in range(trainable_rank, trainable_rank + self.unit):
            self.lora_A[i].requires_grad = True
            self.lora_B[i].requires_grad = True
        for i in range(trainable_rank + self.unit, self.lora_dim):
            self.lora_A[i].requires_grad = False
            self.lora_B[i].requires_grad = False

        lora_A = torch.cat(tuple(self.lora_A), dim=0)
        lora_B = torch.cat(tuple(self.lora_B), dim=1)

        # calculate with lora_A and lora_B
        if self.is_conv2d_3x3:
            ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
            ab = torch.nn.functional.conv2d(ab, lora_B)
        else:
            ab = x
            if self.is_conv2d:
                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)

            ab = torch.nn.functional.linear(ab, lora_A)
            ab = torch.nn.functional.linear(ab, lora_B)

            if self.is_conv2d:
                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])  # (N, H*W, C) -> (N, C, H, W)

        # the last factor is a scaling that weights lower ranks more heavily (presumably)
        result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))

        # NOTE: it might be faster to add the update to the weight first and then call linear/conv2d
        return result
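
    # For example, with lora_dim=16 and unit=4, trainable_rank above is one of {0, 4, 8, 12};
    # only the 4 ranks starting at trainable_rank receive gradients on this step, and, on top of
    # self.scale, the update is multiplied by sqrt(16 / (trainable_rank + 4)), i.e. 2.0 when the
    # first unit is selected down to 1.0 when trainable_rank is 12.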

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        # make the state dict the same as a regular LoRA:
        # nn.ParameterList entries get names like `.lora_A.0`, so cat them just like forward does and swap the keys
        sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)

        lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)

        sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
        sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()

        i = 0
        while True:
            key_a = f"{self.lora_name}.lora_A.{i}"
            key_b = f"{self.lora_name}.lora_B.{i}"
            if key_a in sd:
                sd.pop(key_a)
                sd.pop(key_b)
            else:
                break
            i += 1

        return sd

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # accept the same state dict as a regular LoRA (this approach was suggested by ChatGPT)
        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)

        if lora_A_weight is None or lora_B_weight is None:
            if strict:
                raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
            else:
                return

        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
            lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)

        state_dict.update(
            {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
        )
        state_dict.update(
            {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
        )

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
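
    # NOTE: the two overrides above assume this module is registered inside DyLoRANetwork under its
    # lora_name (see DyLoRANetwork.apply_to), so the keys nn.Module generates for the ParameterLists
    # match the f"{self.lora_name}.lora_A.{i}" pattern that is swapped for lora_down/lora_up weights.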


def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
        network_alpha = 1.0

    # extract dim/alpha for conv2d, and block dim
    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    unit = kwargs.get("unit", None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        assert conv_dim == network_dim, "conv_dim must be same as network_dim"
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    if unit is not None:
        unit = int(unit)
    else:
        unit = 1

    network = DyLoRANetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        apply_to_conv=conv_dim is not None,
        unit=unit,
        varbose=True,
    )
    return network
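
# A minimal usage sketch (illustrative only; `vae`, `text_encoder`, and `unet` are assumed to be the
# already-loaded Stable Diffusion components, and the learning rates are placeholders):
#
#   network = create_network(1.0, 16, 8.0, vae, text_encoder, unet, unit="4")
#   network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)
#   params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4, default_lr=1e-4)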


# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # get dim/alpha mapping
    modules_dim = {}
    modules_alpha = {}
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim
            # print(lora_name, value.size(), dim)

    # support old LoRA without alpha
    for key in modules_dim.keys():
        if key not in modules_alpha:
            modules_alpha[key] = modules_dim[key]

    module_class = DyLoRAModule

    network = DyLoRANetwork(
        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
    )
    return network, weights_sd
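
# Illustrative inference-time call (the weights path is hypothetical):
#
#   network, weights_sd = create_network_from_weights(1.0, "dylora.safetensors", vae, text_encoder, unet)
#   network.apply_to(text_encoder, unet)
#   info = network.load_state_dict(weights_sd, False)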


class DyLoRANetwork(torch.nn.Module):
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
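
    # NOTE: the names above are matched against module.__class__.__name__ in create_modules below,
    # i.e. the Transformer2DModel / ResnetBlock2D / Downsample2D / Upsample2D blocks of the diffusers
    # U-Net and the CLIPAttention / CLIPMLP layers of the transformers text encoder; LoRA is then
    # injected into their Linear / Conv2d children.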

    def __init__(
        self,
        text_encoder,
        unet,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        apply_to_conv=False,
        modules_dim=None,
        modules_alpha=None,
        unit=1,
        module_class=DyLoRAModule,
        varbose=False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.apply_to_conv = apply_to_conv

        if modules_dim is not None:
            print(f"create LoRA network from weights")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
            if self.apply_to_conv:
                print(f"apply LoRA to Conv2d with kernel size (3,3).")

        # create module instances
        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
            prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None
                            if modules_dim is not None:
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha = modules_alpha[lora_name]
                            else:
                                if is_linear or is_conv2d_1x1 or apply_to_conv:
                                    dim = self.lora_dim
                                    alpha = self.alpha

                            if dim is None or dim == 0:
                                continue

                            # dropout and fan_in_fan_out is default
                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
                            loras.append(lora)
            return loras

        self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
        if modules_dim is not None or self.apply_to_conv:
            target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras = create_modules(True, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)
""" | |
def merge_to(self, text_encoder, unet, weights_sd, dtype, device): | |
apply_text_encoder = apply_unet = False | |
for key in weights_sd.keys(): | |
if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER): | |
apply_text_encoder = True | |
elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET): | |
apply_unet = True | |
if apply_text_encoder: | |
print("enable LoRA for text encoder") | |
else: | |
self.text_encoder_loras = [] | |
if apply_unet: | |
print("enable LoRA for U-Net") | |
else: | |
self.unet_loras = [] | |
for lora in self.text_encoder_loras + self.unet_loras: | |
sd_for_lora = {} | |
for key in weights_sd.keys(): | |
if key.startswith(lora.lora_name): | |
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] | |
lora.merge_to(sd_for_lora, dtype, device) | |
print(f"weights are merged") | |
""" | |

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params
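
    # The returned list is a standard set of optimizer parameter groups, e.g. (illustrative):
    #
    #   optimizer = torch.optim.AdamW(network.prepare_optimizer_params(5e-5, 1e-4, 1e-4), lr=1e-4)
    #
    # Groups without an explicit "lr" fall back to the optimizer's default learning rate.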

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    # mask is a tensor with values from 0 to 1
    def set_region(self, sub_prompt_index, is_last_network, mask):
        pass

    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
        pass