|
from typing import Union, Tuple, Literal, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from diffusers import UNet2DConditionModel |
|
from torch import Tensor |
|
from tqdm import tqdm |
|
|
|
from toolkit.config_modules import LoRMConfig |
|
|
|
# Short aliases for the torch layer classes.
conv, lin = nn.Conv2d, nn.Linear

# A 2-D size: either a single int or an explicit pair of ints.
_size_2_t = Union[int, Tuple[int, int]]
|
|
|
# Names of the rank-selection strategies understood by the extract_* helpers.
# `Literal` is required here: `Union['fixed', ...]` is interpreted by `typing`
# as a union of forward references to classes named "fixed", "threshold", ...
# rather than as a closed set of allowed string values.
ExtractMode = Literal[
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage',
]
|
|
|
# Class names (compared against `module.__class__.__name__`) that are handled
# as linear layers during conversion.
LINEAR_MODULES = ['Linear', 'LoRACompatibleLinear']

# Class names handled as convolution layers. The list is empty, so no module
# matches the convolution branch of the converter.
CONV_MODULES = []
|
|
|
# Class names of UNet blocks whose sub-modules are candidates for replacement.
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Downsample2D", "Upsample2D"]

# LoRM targets the same blocks; this is an alias of the list above, not a copy.
LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE

# Dotted names of individual UNet layers.
UNET_TARGET_REPLACE_NAME = [
    "conv_in",
    "conv_out",
    "time_embedding.linear_1",
    "time_embedding.linear_2",
]

# Currently empty.
UNET_MODULES_TO_AVOID = []
|
|
|
|
|
|
|
class LoRMCon2d(nn.Module):
    """Two-stage convolution: a bias-free `down` conv followed by a 1x1 `up` conv.

    `down` maps `in_channels -> lorm_channels` and carries all of the spatial
    settings (kernel size, stride, padding, dilation, groups, padding mode).
    `up` maps `lorm_channels -> out_channels` with a (1, 1) kernel and is the
    only stage that may own a bias.
    """

    def __init__(
            self,
            in_channels: int,
            lorm_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 'same',
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__()

        # Keep the construction arguments available as plain attributes.
        self.in_channels = in_channels
        self.lorm_channels = lorm_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        # Both stages are created on the same device with the same dtype.
        factory_kwargs = {'device': device, 'dtype': dtype}

        # Spatial stage: reduces the channel count, never has a bias.
        self.down = nn.Conv2d(
            in_channels,
            lorm_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
            padding_mode=padding_mode,
            **factory_kwargs
        )

        # Pointwise stage: restores the channel count, optionally with a bias.
        self.up = nn.Conv2d(
            lorm_channels,
            out_channels,
            (1, 1),
            stride=1,
            padding='same',
            dilation=1,
            groups=1,
            bias=bias,
            padding_mode='zeros',
            **factory_kwargs
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        """Apply `down` then `up`; extra positional/keyword arguments are ignored."""
        return self.up(self.down(input))
|
|
|
|
|
class LoRMLinear(nn.Module):
    """Two-stage linear layer: a bias-free `down` projection followed by `up`.

    `down` maps `in_features -> lorm_features`; `up` maps
    `lorm_features -> out_features` and is the only stage that may own a bias.
    """

    def __init__(
            self,
            in_features: int,
            lorm_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None
    ) -> None:
        super().__init__()

        # Keep the construction arguments available as plain attributes.
        self.in_features = in_features
        self.lorm_features = lorm_features
        self.out_features = out_features

        # Both stages are created on the same device with the same dtype.
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.down = nn.Linear(in_features, lorm_features, bias=False, **factory_kwargs)
        self.up = nn.Linear(lorm_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        """Apply `down` then `up`; extra positional/keyword arguments are ignored."""
        return self.up(self.down(input))
|
|
|
|
|
def extract_conv(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu'
) -> Tuple[Tensor, Tensor, int, Tensor]:
    """Factor a 4-D convolution weight into a low-rank (down, up) pair via SVD.

    The weight of shape (out_ch, in_ch, kh, kw) is flattened to
    (out_ch, in_ch * kh * kw), decomposed, and truncated to a rank chosen by
    `mode`:

    - 'fixed':      rank is `mode_param`.
    - 'percentage': rank sized so both factors together hold roughly
                    `mode_param` times the original parameter count.
    - 'threshold':  number of singular values greater than `mode_param`.
    - 'ratio':      number of singular values greater than
                    `mode_param * largest singular value`.
    - 'quantile' / 'percentile': number of leading singular values whose
                    cumulative sum stays below `mode_param * sum(S)`.

    The rank is then clamped to `[1, min(out_ch, in_ch)]` and capped at
    `out_ch // 2` (but never below 1).

    Args:
        weight: convolution weight, shape (out_ch, in_ch, kh, kw).
        mode: rank-selection strategy (see above).
        mode_param: parameter of the selected strategy.
        device: device the decomposition is computed on.

    Returns:
        `(down, up, rank, diff)` where `down` has shape (rank, in_ch, kh, kw),
        `up` has shape (out_ch, rank, 1, 1), `rank` is the rank actually used
        and `diff` is `weight` minus the low-rank reconstruction.

    Raises:
        NotImplementedError: if `mode` is not one of the supported names.
        AssertionError: if `mode_param` is outside the range its mode allows.
    """
    weight = weight.to(device)
    # Height and width are tracked separately so non-square kernels work.
    out_ch, in_ch, kernel_h, kernel_w = weight.shape
    flat_in = in_ch * kernel_h * kernel_w

    # Only the leading singular vectors are used below, so the reduced SVD is
    # sufficient and avoids materialising a (flat_in x flat_in) Vh.
    U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1), full_matrices=False)

    if mode == 'percentage':
        assert 0 <= mode_param <= 1
        desired_params = mode_param * out_ch * flat_in
        # A rank-r factorisation stores r * (flat_in + out_ch) parameters.
        lora_rank = int(desired_params / (flat_in + out_ch))
    elif mode == 'fixed':
        # Coerce so a float coming from a config file can still be used to slice.
        lora_rank = int(mode_param)
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"')

    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # Cap at half the output channels, but never drop to rank 0
        # (which int(out_ch / 2) would produce for out_ch == 1).
        lora_rank = max(1, int(out_ch / 2))
        print("rank is higher than it should be")

    # Fold the singular values into U by scaling its columns.
    U = U[:, :lora_rank] * S[:lora_rank]
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_h, kernel_w)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_h, kernel_w).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    return extract_weight_A, extract_weight_B, lora_rank, diff
|
|
|
|
|
def extract_linear(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu',
) -> Tuple[Tensor, Tensor, int, Tensor]:
    """Factor a 2-D linear weight into a low-rank (down, up) pair via SVD.

    The weight of shape (out_ch, in_ch) is decomposed and truncated to a rank
    chosen by `mode`:

    - 'fixed':      rank is `mode_param`.
    - 'percentage': rank sized so both factors together hold roughly
                    `mode_param` times the original parameter count.
    - 'threshold':  number of singular values greater than `mode_param`.
    - 'ratio':      number of singular values greater than
                    `mode_param * largest singular value`.
    - 'quantile' / 'percentile': number of leading singular values whose
                    cumulative sum stays below `mode_param * sum(S)`.

    The rank is then clamped to `[1, min(out_ch, in_ch)]` and capped at
    `out_ch // 2` (but never below 1).

    Args:
        weight: linear weight, shape (out_ch, in_ch).
        mode: rank-selection strategy (see above).
        mode_param: parameter of the selected strategy.
        device: device the decomposition is computed on.

    Returns:
        `(down, up, rank, diff)` where `down` has shape (rank, in_ch), `up`
        has shape (out_ch, rank), `rank` is the rank actually used and `diff`
        is `weight - up @ down`.

    Raises:
        NotImplementedError: if `mode` is not one of the supported names.
        AssertionError: if `mode_param` is outside the range its mode allows.
    """
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    # Only the leading singular vectors are used below, so the reduced SVD is
    # sufficient and cheaper than the full decomposition.
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    if mode == 'percentage':
        assert 0 <= mode_param <= 1
        desired_params = mode_param * out_ch * in_ch
        # A rank-r factorisation stores r * (in_ch + out_ch) parameters.
        lora_rank = int(desired_params / (in_ch + out_ch))
    elif mode == 'fixed':
        # Coerce so a float coming from a config file can still be used to slice.
        lora_rank = int(mode_param)
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile' or mode == 'percentile':
        # 'percentile' is accepted as an alias, matching extract_conv.
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"')

    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # Cap at half the output features, but never drop to rank 0
        # (which int(out_ch / 2) would produce for out_ch == 1).
        lora_rank = max(1, int(out_ch / 2))

    # Fold the singular values into U by scaling its columns.
    U = U[:, :lora_rank] * S[:lora_rank]
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    return extract_weight_A, extract_weight_B, lora_rank, diff
|
|
|
|
|
def replace_module_by_path(network, name, module):
    """Replace a module in a network by its name.

    `name` is a dot-separated attribute path relative to `network`. Every
    component except the last is resolved with `getattr`; the last one is
    assigned with `setattr`. A failure during the assignment is printed and
    otherwise ignored.
    """
    *parent_parts, leaf = name.split('.')

    # Walk down to the object that owns the attribute being replaced.
    owner = network
    for attr in parent_parts:
        owner = getattr(owner, attr)

    try:
        setattr(owner, leaf, module)
    except Exception as e:
        print(e)
|
|
|
|
|
def count_parameters(module):
    """Return the total number of elements over all of `module.parameters()`."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
|
|
|
|
|
def compute_optimal_bias(original_module, linear_down, linear_up, X):
    """Return the mean over dim 0 of the error between original and approximation.

    The error is `original_module(X) - linear_up(linear_down(X))`.
    """
    residual = original_module(X) - linear_up(linear_down(X))
    return residual.mean(dim=0)
|
|
|
|
|
def format_with_commas(n):
    """Render `n` with a comma as the thousands separator."""
    return format(n, ',')
|
|
|
|
|
def print_lorm_extract_details(
        start_num_params: int,
        end_num_params: int,
        num_replaced: int,
):
    """Print a summary of a UNet conversion with right-aligned numbers.

    Args:
        start_num_params: parameter count before the conversion.
        end_num_params: parameter count after the conversion.
        num_replaced: number of modules that were converted.
    """
    # Column width is the longest comma-grouped rendering of the three numbers.
    width = max(
        len(format_with_commas(value))
        for value in (start_num_params, end_num_params, num_replaced)
    )

    summary_lines = (
        "Convert UNet result:",
        f" - converted: {num_replaced:>{width},} modules",
        f" - start: {start_num_params:>{width},} params",
        f" - end: {end_num_params:>{width},} params",
    )
    for line in summary_lines:
        print(line)
|
|
|
|
|
# Name fragments to ignore.
lorm_ignore_if_contains = ['proj_out', 'proj_in']

# Parameter-count threshold (one million).
lorm_parameter_threshold = 1_000_000
|
|
|
|
|
def _build_lorm_linear(child_module, extract_mode, extract_mode_param):
    """Build a LoRMLinear from the SVD factors of `child_module`, or None."""
    # NOTE(review): the replacement is always created in float32, regardless of
    # the source layer's dtype — confirm this is intentional.
    dtype = torch.float32

    down_weight, up_weight, lora_dim, _diff = extract_linear(
        weight=child_module.weight.clone().detach().float(),
        mode=extract_mode,
        mode_param=extract_mode_param,
        device=child_module.weight.device,
    )
    if down_weight is None:
        return None

    down_weight = down_weight.to(dtype=dtype)
    up_weight = up_weight.to(dtype=dtype)
    bias_weight = None
    if child_module.bias is not None:
        bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)

    new_module = LoRMLinear(
        in_features=down_weight.shape[1],
        lorm_features=lora_dim,
        out_features=up_weight.shape[0],
        bias=bias_weight is not None,
        device=down_weight.device,
        dtype=down_weight.dtype
    )

    # Overwrite the freshly initialised parameters with the extracted factors.
    new_module.down.weight.data = down_weight
    new_module.up.weight.data = up_weight
    if bias_weight is not None:
        new_module.up.bias.data = bias_weight
    return new_module


def _build_lorm_conv(child_module, extract_mode, extract_mode_param):
    """Build a LoRMCon2d from the SVD factors of `child_module`, or None."""
    dtype = child_module.weight.dtype

    down_weight, up_weight, lora_dim, _diff = extract_conv(
        weight=child_module.weight.clone().detach().float(),
        mode=extract_mode,
        mode_param=extract_mode_param,
        device=child_module.weight.device,
    )
    if down_weight is None:
        return None

    down_weight = down_weight.to(dtype=dtype)
    up_weight = up_weight.to(dtype=dtype)
    bias_weight = None
    if child_module.bias is not None:
        bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)

    # NOTE(review): `groups` is not forwarded, so the replacement always uses
    # the LoRMCon2d default — confirm grouped convolutions never reach here.
    new_module = LoRMCon2d(
        in_channels=down_weight.shape[1],
        lorm_channels=lora_dim,
        out_channels=up_weight.shape[0],
        kernel_size=child_module.kernel_size,
        dilation=child_module.dilation,
        padding=child_module.padding,
        padding_mode=child_module.padding_mode,
        stride=child_module.stride,
        bias=bias_weight is not None,
        device=down_weight.device,
        dtype=down_weight.dtype
    )

    # Overwrite the freshly initialised parameters with the extracted factors.
    new_module.down.weight.data = down_weight
    new_module.up.weight.data = up_weight
    if bias_weight is not None:
        new_module.up.bias.data = bias_weight
    return new_module


@torch.no_grad()
def convert_diffusers_unet_to_lorm(
        unet: UNet2DConditionModel,
        config: LoRMConfig,
):
    """Replace large linear/conv layers of `unet` with low-rank LoRM modules.

    Every module of `unet` whose class name is listed in
    `UNET_TARGET_REPLACE_MODULE` is scanned. A sub-module is replaced in place
    when all of the following hold:

    - its name contains neither 'proj_out' nor 'proj_in';
    - its class name is in `LINEAR_MODULES` or `CONV_MODULES`;
    - its parameter count exceeds the `parameter_threshold` returned by
      `config.get_config_for_module` for that sub-module.

    The replacement is built from the SVD factors produced by
    `extract_linear` / `extract_conv`, using the `extract_mode` and
    `extract_mode_param` from the same per-module config. The replaced layers
    (largest first) and a parameter-count summary are printed at the end.

    Args:
        unet: model to convert; it is modified in place.
        config: object providing `get_config_for_module(name)`.

    Returns:
        The list of newly created LoRM modules, in replacement order.
    """
    print('Converting UNet to LoRM UNet')
    start_num_params = count_parameters(unet)
    # Snapshot the module list: the loop below mutates the model.
    named_modules = list(unet.named_modules())

    pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
    # (combined_name, original parameter count) for each replaced layer. The
    # count is kept as an int so sorting does not have to parse it back out
    # of a formatted string.
    replaced_layers = []
    converted_modules = []
    ignore_if_contains = [
        'proj_out', 'proj_in',
    ]

    for name, module in named_modules:
        if module.__class__.__name__ in UNET_TARGET_REPLACE_MODULE:
            # Snapshot the children too, since they are swapped out while
            # this loop is running.
            for child_name, child_module in list(module.named_modules()):
                new_module: Union[LoRMCon2d, LoRMLinear, None] = None
                combined_name = f"{name}.{child_name}"

                lorm_config = config.get_config_for_module(combined_name)
                extract_mode = lorm_config.extract_mode
                extract_mode_param = lorm_config.extract_mode_param
                parameter_threshold = lorm_config.parameter_threshold

                child_class_name = child_module.__class__.__name__

                if any(word in child_name for word in ignore_if_contains):
                    pass
                elif child_class_name in LINEAR_MODULES:
                    if count_parameters(child_module) > parameter_threshold:
                        new_module = _build_lorm_linear(
                            child_module, extract_mode, extract_mode_param)
                elif child_class_name in CONV_MODULES:
                    if count_parameters(child_module) > parameter_threshold:
                        new_module = _build_lorm_conv(
                            child_module, extract_mode, extract_mode_param)

                if new_module is not None:
                    replace_module_by_path(unet, combined_name, new_module)
                    converted_modules.append(new_module)
                    replaced_layers.append(
                        (combined_name, count_parameters(child_module)))

        # One tick per top-level entry, matching `total=len(named_modules)`.
        pbar.update(1)
    pbar.close()
    end_num_params = count_parameters(unet)

    # Largest replaced layers first; ties keep their replacement order.
    for layer_name, param_count in sorted(
            replaced_layers, key=lambda entry: entry[1], reverse=True):
        print(f"{layer_name} - {format_with_commas(param_count)}")

    print_lorm_extract_details(
        start_num_params=start_num_params,
        end_num_params=end_num_params,
        num_replaced=len(replaced_layers),
    )

    return converted_modules
|
|