from typing import Union, Tuple, Literal
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm
from toolkit.config_modules import LoRMConfig
_size_2_t = Union[int, Tuple[int, int]]
ExtractMode = Literal[
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage'
]
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
]
CONV_MODULES = [
    # Conv2d extraction is currently disabled; restore these names to enable it.
    # 'Conv2d',
    # 'LoRACompatibleConv'
]
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
# "ResnetBlock2D",
"Downsample2D",
"Upsample2D",
]
LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE
UNET_TARGET_REPLACE_NAME = [
"conv_in",
"conv_out",
"time_embedding.linear_1",
"time_embedding.linear_2",
]
UNET_MODULES_TO_AVOID = [
]
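
# LoRM ("low-rank module") replacement: each targeted Linear/Conv2d weight W is
# factored via SVD into a rank-r "down" projection followed by an "up"
# projection, so the replaced module computes W @ x ~= B @ (A @ x) with far
# fewer parameters than W itself.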
# Low Rank Convolution
class LoRMCon2d(nn.Module):
def __init__(
self,
in_channels: int,
lorm_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 'same',
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super().__init__()
self.in_channels = in_channels
self.lorm_channels = lorm_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.padding_mode = padding_mode
self.down = nn.Conv2d(
in_channels=in_channels,
out_channels=lorm_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=False,
padding_mode=padding_mode,
device=device,
dtype=dtype
)
        # The up kernel is always 1x1: extract_conv factors a KxK conv into a
        # KxK rank-r "down" conv followed by a 1x1 "up" conv.
self.up = nn.Conv2d(
in_channels=lorm_channels,
out_channels=out_channels,
kernel_size=(1, 1),
stride=1,
padding='same',
dilation=1,
groups=1,
bias=bias,
padding_mode='zeros',
device=device,
dtype=dtype
)
def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
x = input
x = self.down(x)
x = self.up(x)
return x
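
# Example (a sketch; channel sizes are illustrative):
#   lorm = LoRMCon2d(in_channels=320, lorm_channels=32, out_channels=320, kernel_size=3)
#   y = lorm(torch.randn(1, 320, 64, 64))  # -> (1, 320, 64, 64), same shape as a full Conv2d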
class LoRMLinear(nn.Module):
def __init__(
self,
in_features: int,
lorm_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None
) -> None:
super().__init__()
self.in_features = in_features
self.lorm_features = lorm_features
self.out_features = out_features
self.down = nn.Linear(
in_features=in_features,
out_features=lorm_features,
bias=False,
device=device,
dtype=dtype
)
self.up = nn.Linear(
in_features=lorm_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype
)
def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
x = input
x = self.down(x)
x = self.up(x)
return x
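
# Example (a sketch): a full nn.Linear(1280, 1280) holds 1,638,400 weight params;
# LoRMLinear(1280, 160, 1280) holds 1280*160 + 160*1280 = 409,600 weight params
# (plus the 1,280-element bias), roughly a 4x reduction at rank 160.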
def extract_conv(
weight: Union[torch.Tensor, nn.Parameter],
mode='fixed',
mode_param=0,
device='cpu'
) -> Tuple[Tensor, Tensor, int, Tensor]:
weight = weight.to(device)
out_ch, in_ch, kernel_size, _ = weight.shape
    U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1), full_matrices=False)
if mode == 'percentage':
assert 0 <= mode_param <= 1 # Ensure it's a valid percentage.
original_params = out_ch * in_ch * kernel_size * kernel_size
desired_params = mode_param * original_params
        # A rank-r factorization holds r*(in_ch*k*k) down params plus out_ch*r
        # up params; solve desired_params = r*(in_ch*k*k + out_ch) for r.
        lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
elif mode == 'fixed':
lora_rank = mode_param
elif mode == 'threshold':
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param).item()
elif mode == 'ratio':
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s).item()
    elif mode in ('quantile', 'percentile'):
        assert 1 >= mode_param >= 0
        # Choose the smallest rank whose cumulative singular-value sum reaches
        # mode_param of the total.
        s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum).item()
else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"')
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # Cap the rank so the factorization is actually smaller than the original.
        lora_rank = out_ch // 2
        print("Determined rank exceeds half the output channels; capping at out_ch // 2")
U = U[:, :lora_rank]
S = S[:lora_rank]
    U = U @ torch.diag(S)  # fold the singular values into the up factor
Vh = Vh[:lora_rank, :]
diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
del U, S, Vh, weight
return extract_weight_A, extract_weight_B, lora_rank, diff
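
# Example (a sketch): factor a 3x3 conv weight at fixed rank 64.
#   W = torch.randn(320, 320, 3, 3)
#   A, B, r, diff = extract_conv(W, mode='fixed', mode_param=64)
#   # A: (64, 320, 3, 3) down kernel; B: (320, 64, 1, 1) up kernel;
#   # diff is the residual left out of the rank-64 factorization.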
def extract_linear(
weight: Union[torch.Tensor, nn.Parameter],
mode='fixed',
mode_param=0,
device='cpu',
) -> Tuple[Tensor, Tensor, int, Tensor]:
weight = weight.to(device)
out_ch, in_ch = weight.shape
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
if mode == 'percentage':
assert 0 <= mode_param <= 1 # Ensure it's a valid percentage.
desired_params = mode_param * out_ch * in_ch
        # A rank-r factorization holds r*in_ch + out_ch*r params; solve
        # desired_params = r*(in_ch + out_ch) for r.
        lora_rank = int(desired_params / (in_ch + out_ch))
elif mode == 'fixed':
lora_rank = mode_param
elif mode == 'threshold':
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param).item()
elif mode == 'ratio':
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s).item()
elif mode == 'quantile':
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum).item()
else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"')
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # Cap the rank so the factorization is actually smaller than the original.
        lora_rank = out_ch // 2
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
diff = (weight - U @ Vh).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
del U, S, Vh, weight
return extract_weight_A, extract_weight_B, lora_rank, diff
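
# Example (a sketch): pick the rank that keeps 90% of the singular-value mass.
#   W = torch.randn(1280, 1280)
#   A, B, r, diff = extract_linear(W, mode='quantile', mode_param=0.9)
#   # A: (r, 1280), B: (1280, r), and W ~= B @ A with residual diff.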
def replace_module_by_path(network, name, module):
"""Replace a module in a network by its name."""
name_parts = name.split('.')
current_module = network
for part in name_parts[:-1]:
current_module = getattr(current_module, part)
try:
setattr(current_module, name_parts[-1], module)
except Exception as e:
print(e)
def count_parameters(module):
return sum(p.numel() for p in module.parameters())
def compute_optimal_bias(original_module, linear_down, linear_up, X):
    """Estimate a constant bias correction for the low-rank approximation:
    the mean residual between original and approximate outputs over a batch X.
    (Referenced only by the commented-out correction in the convert loop.)"""
    Y_original = original_module(X)
Y_approx = linear_up(linear_down(X))
E = Y_original - Y_approx
optimal_bias = E.mean(dim=0)
return optimal_bias
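
# Why the mean residual is optimal: for a fixed linear map W, the bias b that
# minimizes E[||y - (W x + b)||^2] over a batch is b = E[y - W x], i.e. the
# mean error computed above.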
def format_with_commas(n):
return f"{n:,}"
def print_lorm_extract_details(
start_num_params: int,
end_num_params: int,
num_replaced: int,
):
start_formatted = format_with_commas(start_num_params)
end_formatted = format_with_commas(end_num_params)
num_replaced_formatted = format_with_commas(num_replaced)
width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
print(f"Convert UNet result:")
print(f" - converted: {num_replaced:>{width},} modules")
print(f" - start: {start_num_params:>{width},} params")
print(f" - end: {end_num_params:>{width},} params")
lorm_ignore_if_contains = [
    'proj_out', 'proj_in',
]
lorm_parameter_threshold = 1_000_000  # note: per-module thresholds actually come from LoRMConfig
@torch.no_grad()
def convert_diffusers_unet_to_lorm(
unet: UNet2DConditionModel,
config: LoRMConfig,
):
print('Converting UNet to LoRM UNet')
start_num_params = count_parameters(unet)
named_modules = list(unet.named_modules())
num_replaced = 0
pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
layer_names_replaced = []
converted_modules = []
    ignore_if_contains = lorm_ignore_if_contains
for name, module in named_modules:
module_name = module.__class__.__name__
if module_name in UNET_TARGET_REPLACE_MODULE:
for child_name, child_module in module.named_modules():
new_module: Union[LoRMCon2d, LoRMLinear, None] = None
                combined_name = f"{name}.{child_name}"
lorm_config = config.get_config_for_module(combined_name)
extract_mode = lorm_config.extract_mode
extract_mode_param = lorm_config.extract_mode_param
parameter_threshold = lorm_config.parameter_threshold
                # skip projection layers (proj_in / proj_out)
                if any(word in child_name for word in ignore_if_contains):
pass
elif child_module.__class__.__name__ in LINEAR_MODULES:
if count_parameters(child_module) > parameter_threshold:
                        # dtype = child_module.weight.dtype
                        # note: the replacement is kept in float32 rather than
                        # the original dtype
                        dtype = torch.float32
# extract and convert
down_weight, up_weight, lora_dim, diff = extract_linear(
weight=child_module.weight.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=child_module.weight.device,
)
if down_weight is None:
continue
down_weight = down_weight.to(dtype=dtype)
up_weight = up_weight.to(dtype=dtype)
bias_weight = None
if child_module.bias is not None:
bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
# linear layer weights = (out_features, in_features)
new_module = LoRMLinear(
in_features=down_weight.shape[1],
lorm_features=lora_dim,
out_features=up_weight.shape[0],
bias=bias_weight is not None,
device=down_weight.device,
dtype=down_weight.dtype
)
# replace the weights
new_module.down.weight.data = down_weight
new_module.up.weight.data = up_weight
if bias_weight is not None:
new_module.up.bias.data = bias_weight
# else:
# new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data)
# bias_correction = compute_optimal_bias(
# child_module,
# new_module.down,
# new_module.up,
# torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
# )
# new_module.up.bias.data += bias_correction
elif child_module.__class__.__name__ in CONV_MODULES:
if count_parameters(child_module) > parameter_threshold:
dtype = child_module.weight.dtype
down_weight, up_weight, lora_dim, diff = extract_conv(
weight=child_module.weight.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=child_module.weight.device,
)
if down_weight is None:
continue
down_weight = down_weight.to(dtype=dtype)
up_weight = up_weight.to(dtype=dtype)
bias_weight = None
if child_module.bias is not None:
bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
new_module = LoRMCon2d(
in_channels=down_weight.shape[1],
lorm_channels=lora_dim,
out_channels=up_weight.shape[0],
kernel_size=child_module.kernel_size,
dilation=child_module.dilation,
padding=child_module.padding,
padding_mode=child_module.padding_mode,
stride=child_module.stride,
bias=bias_weight is not None,
device=down_weight.device,
dtype=down_weight.dtype
)
# replace the weights
new_module.down.weight.data = down_weight
new_module.up.weight.data = up_weight
if bias_weight is not None:
new_module.up.bias.data = bias_weight
                if new_module is not None:
combined_name = f"{name}.{child_name}"
replace_module_by_path(unet, combined_name, new_module)
converted_modules.append(new_module)
num_replaced += 1
layer_names_replaced.append(
f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
pbar.update(1)
pbar.close()
end_num_params = count_parameters(unet)
    def sorting_key(s):
        # Extract the parameter count after the last "-", drop commas, convert to int
        return int(s.rsplit("-", 1)[1].strip().replace(",", ""))
sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
for layer_name in sorted_layer_names_replaced:
print(layer_name)
print_lorm_extract_details(
start_num_params=start_num_params,
end_num_params=end_num_params,
num_replaced=num_replaced,
)
return converted_modules
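
# Minimal usage sketch (illustrative; LoRMConfig's constructor lives in
# toolkit.config_modules and is not shown here):
#   unet = UNet2DConditionModel.from_pretrained(model_path, subfolder='unet')
#   config = LoRMConfig(...)  # supplies extract_mode, extract_mode_param, parameter_threshold
#   converted = convert_diffusers_unet_to_lorm(unet, config)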