|
|
|
|
|
import math |
|
import os |
|
from typing import Dict, List, Optional, Tuple, Type, Union |
|
from diffusers import AutoencoderKL |
|
from transformers import CLIPTextModel |
|
import numpy as np |
|
import torch |
|
import re |
|
|
|
|
|
# Matches the "<up|down>_blocks_<i>_<kind>_<j>_" segment of a flattened U-Net module
# name (dots already replaced by underscores), capturing direction, block index,
# sub-module kind and sub-module index. Presumably used to derive per-block settings
# from OFT module names — verify against callers.
RE_UPDOWN = re.compile(
    r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_"
)
|
|
|
|
|
class OFTModule(torch.nn.Module):
    """
    Orthogonal fine-tuning (OFT) adapter for a single Linear / Conv layer.

    Replaces the ``forward`` method of the original module (instead of replacing
    the module itself): the original output is multiplied along the channel axis
    by a block-diagonal matrix R built from the trainable ``oft_blocks``.
    """

    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
    ):
        """
        Args:
            oft_name: unique name of this adapter.
            org_module: module to wrap; its class name must contain "Linear" or "Conv".
            multiplier: blend between R (1.0) and the identity (0.0).
            dim: number of diagonal blocks in R; must divide the module's output size.
            alpha: constraint factor; the norm of the skew-symmetric generator is
                clamped to ``alpha * out_dim``.

        Raises:
            ValueError: if ``org_module`` is neither Linear nor Conv, or if its output
                size is not divisible by ``dim``.
        """
        super().__init__()
        self.oft_name = oft_name
        self.num_blocks = dim

        class_name = org_module.__class__.__name__
        if "Linear" in class_name:
            out_dim = org_module.out_features
        elif "Conv" in class_name:
            out_dim = org_module.out_channels
        else:
            raise ValueError(f"OFT supports only Linear or Conv modules, got {class_name}")

        # R is block_diag of num_blocks (block_size x block_size) blocks; it must be
        # exactly out_dim wide or the matmul in forward/merge cannot line up.
        if out_dim % self.num_blocks != 0:
            raise ValueError(
                f"output size {out_dim} of {oft_name} is not divisible by the number of OFT blocks {self.num_blocks}"
            )

        if isinstance(alpha, torch.Tensor):
            # .cpu() so that an alpha tensor living on an accelerator can be converted too
            alpha = alpha.detach().cpu().numpy()
        self.constraint = alpha * out_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        self.block_size = out_dim // self.num_blocks
        # Zero init -> skew-symmetric Q = 0 -> R = I, i.e. the adapter starts as a no-op.
        self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))

        self.out_dim = out_dim
        self.shape = org_module.weight.shape

        self.multiplier = multiplier
        # Kept inside a list so torch.nn.Module does not register it as a submodule.
        self.org_module = [org_module]

    def apply_to(self):
        """Hook this adapter into the wrapped module by swapping its ``forward``."""
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def get_weight(self, multiplier=None):
        """Return the (out_dim, out_dim) block-diagonal matrix R.

        Args:
            multiplier: blend factor between the orthogonal blocks (1.0) and the
                identity (0.0); defaults to ``self.multiplier``.
        """
        if multiplier is None:
            multiplier = self.multiplier

        # Skew-symmetric generator per block: Q = A - A^T.
        block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        # Clamp the global norm of Q to the constraint; the epsilons avoid 0/0 at init.
        norm_Q = torch.norm(block_Q.flatten())
        new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
        block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
        I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
        # Cayley transform: (I + Q)(I - Q)^-1 is orthogonal for skew-symmetric Q.
        block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())

        # Use the resolved local multiplier so an explicit argument is honored.
        block_R_weighted = multiplier * block_R + (1 - multiplier) * I
        return torch.block_diag(*block_R_weighted)

    def forward(self, x, scale=None):
        """Run the original forward, then multiply the channel axis by R.

        ``scale`` is accepted for call compatibility and is not used.
        """
        x = self.org_forward(x)
        if self.multiplier == 0.0:
            return x

        R = self.get_weight().to(x.device, dtype=x.dtype)
        if x.dim() == 4:
            # (N, C, H, W): move channels last for the matmul, then restore the layout.
            x = x.permute(0, 2, 3, 1)
            x = torch.matmul(x, R)
            x = x.permute(0, 3, 1, 2)
        else:
            x = torch.matmul(x, R)
        return x
|
|
|
|
|
class OFTInfModule(OFTModule):
    """OFT module for inference: can be switched off and can bake R into the wrapped layer."""

    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
        **kwargs,
    ):
        # Extra keyword arguments are accepted for call-site compatibility and ignored.
        super().__init__(oft_name, org_module, multiplier, dim, alpha)
        # When False, forward bypasses OFT and calls the original forward directly.
        self.enabled = True
        # Owning network, assigned through set_network.
        self.network: OFTNetwork = None

    def set_network(self, network):
        self.network = network

    def forward(self, x, scale=None):
        if not self.enabled:
            return self.org_forward(x)
        return super().forward(x, scale)

    def merge_to(self, multiplier=None, sign=1):
        """Fold R into the wrapped module's weight in place.

        The new weight is R^T @ W (per output channel), so that calling the original
        forward reproduces ``x @ W^T @ R``.

        Args:
            multiplier: blend factor forwarded to ``get_weight`` (None -> module default).
            sign: non-negative merges R. Negative applies R^-1 instead, which undoes a
                previous merge made with the same multiplier. Scaling R by -1 would
                not undo a rotation.

        NOTE(review): only "weight" is rewritten. The unmerged forward rotates the
        full output, bias included, so for modules with a bias the merged layer
        computes x W^T R + b instead of (x W^T + b) R. Rotating the bias too would
        also require the owning network's weight backup/restore to save the bias.
        """
        with torch.no_grad():
            R = self.get_weight(multiplier)
            if sign < 0:
                # Invert in float32: low-precision inverse is not supported everywhere.
                R = R.float().inverse()

            org_sd = self.org_module[0].state_dict()
            org_weight = org_sd["weight"]
            R = R.to(org_weight.device, dtype=org_weight.dtype)

            if org_weight.dim() == 4:
                weight = torch.einsum("oihw, op -> pihw", org_weight, R)
            else:
                weight = torch.einsum("oi, op -> pi", org_weight, R)

            org_sd["weight"] = weight
            self.org_module[0].load_state_dict(org_sd)
|
|
|
|
|
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: AutoencoderKL,
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a fresh OFTNetwork for training.

    Args:
        multiplier: initial multiplier of every OFT module.
        network_dim: number of OFT blocks per module (default 4).
        network_alpha: OFT constraint factor (default 1.0).
        vae, neuron_dropout: unused here.
        text_encoder, unet: forwarded to OFTNetwork.
        **kwargs: optional "enable_all_linear" / "enable_conv" flags; booleans or
            strings such as "True" / "False" / "1" / "0".

    Returns:
        The constructed OFTNetwork.
    """
    if network_dim is None:
        network_dim = 4
    if network_alpha is None:
        network_alpha = 1.0

    def to_bool(value):
        # Network args may arrive as strings, and bool("False") is True, so
        # recognize common false spellings explicitly. Any other value keeps the
        # previous truthiness-based behavior.
        if value is None:
            return None
        if isinstance(value, str):
            return value.strip().lower() not in ("false", "0", "no", "off", "none", "")
        return bool(value)

    enable_all_linear = to_bool(kwargs.get("enable_all_linear", None))
    enable_conv = to_bool(kwargs.get("enable_conv", None))

    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=network_dim,
        alpha=network_alpha,
        enable_all_linear=enable_all_linear,
        enable_conv=enable_conv,
        varbose=True,
    )
    return network
|
|
|
|
|
|
|
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    """Build an OFTNetwork whose configuration matches saved OFT weights.

    Args:
        multiplier: multiplier for every OFT module.
        file: path to a ``.safetensors`` or torch-pickled state dict; only read when
            ``weights_sd`` is None.
        vae: unused.
        text_encoder, unet: forwarded to OFTNetwork; ``unet`` is also scanned to find
            which kinds of modules the saved weights belong to.
        weights_sd: already-loaded state dict, if available.
        for_inference: build OFTInfModule instances instead of OFTModule.

    Returns:
        ``(network, weights_sd)``; the weights are returned, not loaded into the network.
    """
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE(review): torch.load unpickles the file; only use trusted weight files.
            weights_sd = torch.load(file, map_location="cpu")

    # Every module of a network shares the same dim (num blocks) and alpha, so the
    # first occurrence of each is enough.
    dim = None
    alpha = None
    for name, param in weights_sd.items():
        if name.endswith(".alpha"):
            if alpha is None:
                alpha = param.item()
        elif dim is None:
            dim = param.size()[0]  # oft_blocks: (num_blocks, block_size, block_size)
        if dim is not None and alpha is not None:
            break

    # oft_blocks are 3-D for Linear and Conv modules alike, so the module type cannot
    # be read from tensor shapes. Instead, map the saved module names back onto the
    # U-Net target modules using OFTNetwork's naming scheme:
    # (prefix + "." + module_name + "." + child_name) with "." replaced by "_".
    saved_names = {key.split(".", 1)[0] for key in weights_sd}

    def module_prefixes(class_names):
        return tuple(
            (OFTNetwork.OFT_PREFIX_UNET + "." + name + ".").replace(".", "_")
            for name, module in unet.named_modules()
            if module.__class__.__name__ in class_names
        )

    conv_prefixes = module_prefixes(OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3)
    attn_prefixes = module_prefixes(OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY)
    linear_prefixes = module_prefixes(OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR)

    # Conv blocks are only targeted when enable_conv was set.
    has_conv2d = any(n.startswith(conv_prefixes) for n in saved_names)
    # "All linear" adds Transformer2DModel children that are outside the attention modules.
    all_linear = any(
        n.startswith(linear_prefixes) and not n.startswith(attn_prefixes) and not n.startswith(conv_prefixes)
        for n in saved_names
    )

    module_class = OFTInfModule if for_inference else OFTModule
    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=dim,
        alpha=alpha,
        enable_all_linear=all_linear,
        enable_conv=has_conv2d,
        module_class=module_class,
    )
    return network, weights_sd
|
|
|
|
|
class OFTNetwork(torch.nn.Module):
    """OFT adapter network for the U-Net (text encoders are not adapted).

    One ``module_class`` instance is created for each Linear / Conv2d child of every
    U-Net module whose class name is in the selected target list.
    """

    UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
    UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    OFT_PREFIX_UNET = "oft_unet"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet,
        multiplier: float = 1.0,
        dim: int = 4,
        alpha: float = 1,
        enable_all_linear: Optional[bool] = False,
        enable_conv: Optional[bool] = False,
        module_class: Type[object] = OFTModule,
        varbose: Optional[bool] = False,
    ) -> None:
        """
        Args:
            text_encoder: unused (kept for interface compatibility).
            unet: model scanned for target modules.
            multiplier: initial multiplier of every OFT module.
            dim: number of OFT blocks per module.
            alpha: OFT constraint factor.
            enable_all_linear: target every Linear in Transformer2DModel instead of
                only those in CrossAttention.
            enable_conv: also target ResnetBlock2D / Downsample2D / Upsample2D,
                including convolutions larger than 1x1.
            module_class: OFT module class to instantiate.
            varbose: unused.
        """
        super().__init__()
        self.multiplier = multiplier
        self.dim = dim
        self.alpha = alpha

        print(
            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
        )

        def create_modules(
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
        ) -> List[OFTModule]:
            prefix = self.OFT_PREFIX_UNET
            ofts = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    is_linear = "Linear" in child_module.__class__.__name__
                    is_conv2d = "Conv2d" in child_module.__class__.__name__
                    is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                    # 1x1 convs are treated like Linear; larger kernels only with enable_conv.
                    if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
                        oft_name = (prefix + "." + name + "." + child_name).replace(".", "_")
                        oft = module_class(
                            oft_name,
                            child_module,
                            self.multiplier,
                            dim,
                            alpha,
                        )
                        ofts.append(oft)
            return ofts

        # Copy the class-level list: "+=" directly on it would permanently append the
        # conv targets to the shared class attribute and leak into later networks.
        if enable_all_linear:
            target_modules = list(OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR)
        else:
            target_modules = list(OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY)
        if enable_conv:
            target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
        print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")

        # Names become submodule names and state-dict prefixes, so they must be unique.
        names = set()
        for oft in self.unet_ofts:
            assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
            names.add(oft.oft_name)

    def set_multiplier(self, multiplier):
        """Set the multiplier on the network and on every OFT module."""
        self.multiplier = multiplier
        for oft in self.unet_ofts:
            oft.multiplier = self.multiplier

    def load_weights(self, file):
        """Load a state dict from ``file`` non-strictly; returns load_state_dict's result."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE(review): torch.load unpickles the file; only use trusted weight files.
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        """Hook every OFT module into its layer and register it as a submodule."""
        assert apply_unet, "apply_unet must be True"

        for oft in self.unet_ofts:
            oft.apply_to()
            self.add_module(oft.oft_name, oft)

    def is_mergeable(self):
        return True

    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        """Load ``weights_sd`` into each OFT module and merge it into its layer.

        Requires modules exposing ``merge_to`` (e.g. built with OFTInfModule).
        ``dtype`` and ``device`` are not used.
        """
        print("enable OFT for U-Net")

        for oft in self.unet_ofts:
            # Include the "." separator so that one module name being a prefix of
            # another cannot pull in the wrong module's tensors.
            key_prefix = oft.oft_name + "."
            sd_for_oft = {
                key[len(key_prefix) :]: value for key, value in weights_sd.items() if key.startswith(key_prefix)
            }
            oft.load_state_dict(sd_for_oft, False)
            oft.merge_to()

        print("weights are merged")

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Return a single optimizer param group with all OFT parameters.

        ``unet_lr`` (if given) becomes the group's lr; the other learning rates are unused.
        """
        self.requires_grad_(True)

        def enumerate_params(ofts):
            params = [p for oft in ofts for p in oft.parameters()]
            num_params = sum(p.numel() for p in params)
            print(f"OFT params: {num_params}")
            return params

        param_data = {"params": enumerate_params(self.unet_ofts)}
        if unet_lr is not None:
            param_data["lr"] = unet_lr
        return [param_data]

    def enable_gradient_checkpointing(self):
        # Nothing to do for OFT modules.
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save the state dict to ``file`` (safetensors when the extension says so).

        When ``dtype`` is given, tensors are copied to CPU and cast first. For
        safetensors, model hashes are added to the metadata.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        """Keep a copy of each wrapped layer's original weight (only the first time)."""
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                org_module._lora_restored = True

    def restore_weights(self):
        """Put back the backed-up weight of every layer modified by pre_calculation."""
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            # Layers never backed up or merged carry no flag and need no restore.
            if not getattr(org_module, "_lora_restored", True):
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        """Merge every OFT module into its layer and switch the OFT forward off.

        Requires modules exposing ``merge_to`` / ``enabled`` (e.g. OFTInfModule).
        NOTE(review): OFTInfModule.merge_to rewrites only the weight, while the
        unmerged forward also rotates the bias, so results may differ for layers
        with a bias — verify.
        """
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            oft.merge_to()
            # Mark the layer as modified so restore_weights puts the backup back.
            org_module._lora_restored = False
            # R is now baked into the weight; bypass the OFT forward.
            oft.enabled = False
|
|