from typing import Union, Tuple, Literal, Optional

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm

from toolkit.config_modules import LoRMConfig

conv = nn.Conv2d
lin = nn.Linear

_size_2_t = Union[int, Tuple[int, int]]

ExtractMode = Literal[
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage'
]

LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear'
]
# Conv conversion is currently disabled; no module names are active in this list.
CONV_MODULES = [
    # 'Conv2d',
    # 'LoRACompatibleConv'
]

UNET_TARGET_REPLACE_MODULE = [
    "Transformer2DModel",
    # "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
]

LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE

UNET_TARGET_REPLACE_NAME = [
    "conv_in",
    "conv_out",
    "time_embedding.linear_1",
    "time_embedding.linear_2",
]

UNET_MODULES_TO_AVOID = [
]


# Low rank convolution: a full Conv2d is replaced by a low-rank "down" conv with
# the original kernel size followed by a 1x1 "up" conv.
class LoRMCon2d(nn.Module):
    def __init__(
            self,
            in_channels: int,
            lorm_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 'same',
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.lorm_channels = lorm_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        self.down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=lorm_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype
        )
        # The up projection always uses a 1x1 kernel; the SVD extraction below only
        # yields a (k x k) down kernel paired with a 1x1 up kernel, not two k x k kernels.
        self.up = nn.Conv2d(
            in_channels=lorm_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
            stride=1,
            padding='same',
            dilation=1,
            groups=1,
            bias=bias,
            padding_mode='zeros',
            device=device,
            dtype=dtype
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        x = input
        x = self.down(x)
        x = self.up(x)
        return x


# Low rank linear: a full Linear is replaced by a bias-free "down" projection
# followed by an "up" projection that carries the original bias.
class LoRMLinear(nn.Module):
    def __init__(
            self,
            in_features: int,
            lorm_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.lorm_features = lorm_features
        self.out_features = out_features

        self.down = nn.Linear(
            in_features=in_features,
            out_features=lorm_features,
            bias=False,
            device=device,
            dtype=dtype
        )
        self.up = nn.Linear(
            in_features=lorm_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        x = input
        x = self.down(x)
        x = self.up(x)
        return x
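

# Illustrative sketch (not part of the toolkit, never called): a LoRMLinear factors an
# (out_features x in_features) weight into a rank-r pair, cutting the parameter count
# from in*out to roughly r*(in + out). The layer sizes below are hypothetical examples.
def _example_lorm_linear_savings(in_features: int = 1280, out_features: int = 1280, rank: int = 64):
    full = nn.Linear(in_features, out_features)
    lorm = LoRMLinear(in_features=in_features, lorm_features=rank, out_features=out_features)
    full_params = sum(p.numel() for p in full.parameters())
    lorm_params = sum(p.numel() for p in lorm.parameters())
    print(f"full: {full_params:,} params, LoRM rank {rank}: {lorm_params:,} params")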


def extract_conv(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu'
) -> Tuple[Tensor, Tensor, int, Tensor]:
    """SVD-factor a (square-kernel) conv weight into down (rank, in, k, k) and up
    (out, rank, 1, 1) weights. Returns (down, up, rank, residual)."""
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1))

    if mode == 'percentage':
        # Ensure it's a valid percentage.
        assert 0 <= mode_param <= 1
        original_params = out_ch * in_ch * kernel_size * kernel_size
        desired_params = mode_param * original_params
        # Solve for lora_rank from: rank * (in_ch * k * k + out_ch) = desired_params
        lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
    elif mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"'
        )

    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)

    if lora_rank >= out_ch / 2:
        lora_rank = int(out_ch / 2)
        print("rank is higher than it should be")
        # print(f"Skipping layer as determined rank is too high")
        # return None, None, None, None
        # return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, lora_rank, diff


def extract_linear(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu',
) -> Tuple[Tensor, Tensor, int, Tensor]:
    """SVD-factor a linear weight into down (rank, in) and up (out, rank) weights.
    Returns (down, up, rank, residual)."""
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    U, S, Vh = torch.linalg.svd(weight)

    if mode == 'percentage':
        # Ensure it's a valid percentage.
        assert 0 <= mode_param <= 1
        desired_params = mode_param * out_ch * in_ch
        # Solve for lora_rank from: rank * (in_ch + out_ch) = desired_params
        lora_rank = int(desired_params / (in_ch + out_ch))
    elif mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"'
        )

    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)

    if lora_rank >= out_ch / 2:
        # print(f"rank is higher than it should be")
        lora_rank = int(out_ch / 2)
        # return weight, 'full'
        # print(f"Skipping layer as determined rank is too high")
        # return None, None, None, None

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, lora_rank, diff
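

# Illustrative sketch (not part of the toolkit, never called): extract_linear returns
# SVD factors A (rank x in) and B (out x rank) such that B @ A is the rank-r
# approximation of the weight, and `diff` is the residual weight - B @ A. The random
# matrix below is purely an example input.
def _example_extract_linear_check():
    weight = torch.randn(256, 512)
    A, B, rank, diff = extract_linear(weight, mode='fixed', mode_param=16)
    approx = B @ A
    print(rank, torch.allclose(weight - approx, diff, atol=1e-5))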


def replace_module_by_path(network, name, module):
    """Replace a module in a network by its name."""
    name_parts = name.split('.')
    current_module = network
    for part in name_parts[:-1]:
        current_module = getattr(current_module, part)
    try:
        setattr(current_module, name_parts[-1], module)
    except Exception as e:
        print(e)


def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


def compute_optimal_bias(original_module, linear_down, linear_up, X):
    Y_original = original_module(X)
    Y_approx = linear_up(linear_down(X))
    E = Y_original - Y_approx
    optimal_bias = E.mean(dim=0)
    return optimal_bias


def format_with_commas(n):
    return f"{n:,}"


def print_lorm_extract_details(
        start_num_params: int,
        end_num_params: int,
        num_replaced: int,
):
    start_formatted = format_with_commas(start_num_params)
    end_formatted = format_with_commas(end_num_params)
    num_replaced_formatted = format_with_commas(num_replaced)

    width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))

    print(f"Convert UNet result:")
    print(f" - converted: {num_replaced:>{width},} modules")
    print(f" - start:     {start_num_params:>{width},} params")
    print(f" - end:       {end_num_params:>{width},} params")


lorm_ignore_if_contains = [
    'proj_out',
    'proj_in',
]

lorm_parameter_threshold = 1000000
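

# Illustrative sketch (not part of the toolkit, never called): replace_module_by_path
# swaps a submodule in place given its dotted name, which is how the converter below
# writes LoRM modules back into the UNet. The toy network here is hypothetical.
def _example_replace_module_by_path():
    net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    # Replace the first Linear (named "0") with a rank-2 LoRMLinear of the same shape.
    replace_module_by_path(net, "0", LoRMLinear(in_features=8, lorm_features=2, out_features=8))
    print(net)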


@torch.no_grad()
def convert_diffusers_unet_to_lorm(
        unet: UNet2DConditionModel,
        config: LoRMConfig,
):
    print('Converting UNet to LoRM UNet')
    start_num_params = count_parameters(unet)
    named_modules = list(unet.named_modules())

    num_replaced = 0
    pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
    layer_names_replaced = []
    converted_modules = []

    ignore_if_contains = [
        'proj_out',
        'proj_in',
    ]

    for name, module in named_modules:
        module_name = module.__class__.__name__
        if module_name in UNET_TARGET_REPLACE_MODULE:
            for child_name, child_module in module.named_modules():
                new_module: Union[LoRMCon2d, LoRMLinear, None] = None
                combined_name = f"{name}.{child_name}"

                lorm_config = config.get_config_for_module(combined_name)
                extract_mode = lorm_config.extract_mode
                extract_mode_param = lorm_config.extract_mode_param
                parameter_threshold = lorm_config.parameter_threshold

                # skip modules whose name matches the ignore list
                if any([word in child_name for word in ignore_if_contains]):
                    pass
                elif child_module.__class__.__name__ in LINEAR_MODULES:
                    if count_parameters(child_module) > parameter_threshold:
                        # dtype = child_module.weight.dtype
                        dtype = torch.float32
                        # extract and convert
                        down_weight, up_weight, lora_dim, diff = extract_linear(
                            weight=child_module.weight.clone().detach().float(),
                            mode=extract_mode,
                            mode_param=extract_mode_param,
                            device=child_module.weight.device,
                        )
                        if down_weight is None:
                            continue
                        down_weight = down_weight.to(dtype=dtype)
                        up_weight = up_weight.to(dtype=dtype)
                        bias_weight = None
                        if child_module.bias is not None:
                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)

                        # linear layer weights = (out_features, in_features)
                        new_module = LoRMLinear(
                            in_features=down_weight.shape[1],
                            lorm_features=lora_dim,
                            out_features=up_weight.shape[0],
                            bias=bias_weight is not None,
                            device=down_weight.device,
                            dtype=down_weight.dtype
                        )
                        # replace the weights
                        new_module.down.weight.data = down_weight
                        new_module.up.weight.data = up_weight
                        if bias_weight is not None:
                            new_module.up.bias.data = bias_weight
                        # else:
                        #     new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data)
                        # bias_correction = compute_optimal_bias(
                        #     child_module,
                        #     new_module.down,
                        #     new_module.up,
                        #     torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
                        # )
                        # new_module.up.bias.data += bias_correction

                elif child_module.__class__.__name__ in CONV_MODULES:
                    if count_parameters(child_module) > parameter_threshold:
                        dtype = child_module.weight.dtype
                        down_weight, up_weight, lora_dim, diff = extract_conv(
                            weight=child_module.weight.clone().detach().float(),
                            mode=extract_mode,
                            mode_param=extract_mode_param,
                            device=child_module.weight.device,
                        )
                        if down_weight is None:
                            continue
                        down_weight = down_weight.to(dtype=dtype)
                        up_weight = up_weight.to(dtype=dtype)
                        bias_weight = None
                        if child_module.bias is not None:
                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)

                        new_module = LoRMCon2d(
                            in_channels=down_weight.shape[1],
                            lorm_channels=lora_dim,
                            out_channels=up_weight.shape[0],
                            kernel_size=child_module.kernel_size,
                            dilation=child_module.dilation,
                            padding=child_module.padding,
                            padding_mode=child_module.padding_mode,
                            stride=child_module.stride,
                            bias=bias_weight is not None,
                            device=down_weight.device,
                            dtype=down_weight.dtype
                        )
                        # replace the weights
                        new_module.down.weight.data = down_weight
                        new_module.up.weight.data = up_weight
                        if bias_weight is not None:
                            new_module.up.bias.data = bias_weight

                if new_module:
                    replace_module_by_path(unet, combined_name, new_module)
                    converted_modules.append(new_module)
                    num_replaced += 1
                    layer_names_replaced.append(
                        f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
        pbar.update(1)
    pbar.close()

    end_num_params = count_parameters(unet)

    def sorting_key(s):
        # Extract the number part, remove commas, and convert to integer
        return int(s.split("-")[1].strip().replace(",", ""))

    sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
    for layer_name in sorted_layer_names_replaced:
        print(layer_name)

    print_lorm_extract_details(
        start_num_params=start_num_params,
        end_num_params=end_num_params,
        num_replaced=num_replaced,
    )
    return converted_modules
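

# Usage sketch (assumptions flagged, never called): convert_diffusers_unet_to_lorm
# swaps eligible Linear layers of a diffusers UNet2DConditionModel for LoRMLinear
# factorizations in place and returns the new modules. The model id and the
# no-argument LoRMConfig() construction below are assumptions for illustration;
# the real constructor lives in toolkit.config_modules and may require arguments.
def _example_convert_unet():
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    lorm_config = LoRMConfig()  # hypothetical construction; see toolkit.config_modules
    converted = convert_diffusers_unet_to_lorm(unet, lorm_config)
    print(f"converted {len(converted)} LoRM modules")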