import torch
import time

import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui

from tqdm import tqdm
from backend import memory_management, utils
from backend.args import dynamic_args


class ForgeLoraCollection:
    # TODO
    pass


extra_weight_calculators = {}

lora_utils_forge = ForgeLoraCollection()

lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]


def get_function(function_name: str):
    for lora_collection in lora_collection_priority:
        if hasattr(lora_collection, function_name):
            return getattr(lora_collection, function_name)


def load_lora(lora, to_load):
    patch_dict, remaining_dict = get_function('load_lora')(lora, to_load)
    return patch_dict, remaining_dict


def model_lora_keys_clip(model, key_map={}):
    return get_function('model_lora_keys_clip')(model, key_map)


def model_lora_keys_unet(model, key_map={}):
    return get_function('model_lora_keys_unet')(model, key_map)


@torch.inference_mode()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33

    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
    lora_diff *= alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)

    if strength != 1.0:
        weight_calc -= weight
        weight += strength * weight_calc
    else:
        weight[:] = weight_calc

    return weight
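# Summary of the decomposition above (read from the code, DoRA-style): the merged
# weight W + alpha * lora_diff is renormalized per input channel (dim 1),
#     W' = dora_scale * (W + alpha * lora_diff) / ||W + alpha * lora_diff||,
# and when strength != 1.0 the output is the interpolation W + strength * (W' - W).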
@torch.inference_mode()
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446

    weight_original_dtype = weight.dtype
    weight = weight.to(dtype=computation_dtype)

    for p in patches:
        strength = p[0]
        v = p[1]
        strength_model = p[2]
        offset = p[3]
        function = p[4]
        if function is None:
            function = lambda a: a

        old_weight = None
        if offset is not None:
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            weight *= strength_model

        if isinstance(v, list):
            v = (merge_lora_to_weight(v[1:], v[0].clone(), key),)

        patch_type = ''
        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]

        if patch_type == "diff":
            w1 = v[0]
            if strength != 0.0:
                if w1.shape != weight.shape:
                    if w1.ndim == weight.ndim == 4:
                        new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                        print(f'Merged with {key} channel changed to {new_shape}')
                        new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
                        new_weight = torch.zeros(size=new_shape).to(weight)
                        new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
                        new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
                        new_weight = new_weight.contiguous().clone()
                        weight = new_weight
                    else:
                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                else:
                    weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
        elif patch_type == "lora":
            mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
            mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
            dora_scale = v[4]
            if v[2] is not None:
                alpha = v[2] / mat2.shape[0]
            else:
                alpha = 1.0

            if v[3] is not None:
                mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)

            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "lokr":
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is None:
                dim = w1_b.shape[0]
                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1_b, weight.device, computation_dtype))
            else:
                w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype)

            if w2 is None:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2_b, weight.device, computation_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_b, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_a, weight.device, computation_dtype))
            else:
                w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype)

            if len(w2.shape) == 4:
                w1 = w1.unsqueeze(2).unsqueeze(2)

            if v[2] is not None and dim is not None:
                alpha = v[2] / dim
            else:
                alpha = 1.0

            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "loha":
            w1a = v[0]
            w1b = v[1]
            if v[2] is not None:
                alpha = v[2] / w1b.shape[0]
            else:
                alpha = 1.0

            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]
            if v[5] is not None:
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t1, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1a, weight.device, computation_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2a, weight.device, computation_dtype))
            else:
                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1b, weight.device, computation_dtype))
                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w2b, weight.device, computation_dtype))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "glora":
            if v[4] is not None:
                alpha = v[4] / v[0].shape[0]
            else:
                alpha = 1.0

            dora_scale = v[5]

            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)

            try:
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type in extra_weight_calculators:
            weight = extra_weight_calculators[patch_type](weight, strength, v)
        else:
            print("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            weight = old_weight

    weight = weight.to(dtype=weight_original_dtype)
    return weight
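# Illustrative sketch (not part of the original module): a minimal "lora"-type patch
# as consumed by merge_lora_to_weight above. Names and shapes are made up; the tuple
# layout (mat1, mat2, alpha, mat3, dora_scale) follows the indexing in the "lora" branch.
#
#     up = torch.randn(320, 4)      # mat1: (out_features, rank)
#     down = torch.randn(4, 320)    # mat2: (rank, in_features)
#     patch = [(1.0, ("lora", (up, down, None, None, None)), 1.0, None, None)]
#     merged = merge_lora_to_weight(patch, torch.randn(320, 320), key="demo")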
from backend import operations


class LoraLoader:
    def __init__(self, model):
        self.model = model
        self.patches = {}
        self.backup = {}
        self.online_backup = []
        self.dirty = False

    def clear_patches(self):
        self.patches.clear()
        self.dirty = True
        return

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        model_sd = self.model.state_dict()

        for k in patches:
            offset = None
            function = None
            if isinstance(k, str):
                key = k
            else:
                offset = k[1]
                key = k[0]
                if len(k) > 2:
                    function = k[2]

            if key in model_sd:
                p.add(k)
                current_patches = self.patches.get(key, [])
                current_patches.append([strength_patch, patches[k], strength_model, offset, function])
                self.patches[key] = current_patches

        self.dirty = True
        return list(p)
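    # Note on the method below: refresh() first restores any weights recorded in
    # self.backup, then either attaches the patch lists to their parent layers as
    # `forge_online_loras` (when dynamic_args['online_lora'] is set) or merges them
    # into the weights in place, with special handling for BNB 4-bit and GGUF
    # quantized parameters.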
    @torch.inference_mode()
    def refresh(self, target_device=None, offload_device=torch.device('cpu')):
        if not self.dirty:
            return

        self.dirty = False
        execution_start_time = time.perf_counter()

        # Restore
        for m in set(self.online_backup):
            del m.forge_online_loras

        self.online_backup = []

        for k, w in self.backup.items():
            if not isinstance(w, torch.nn.Parameter):
                # In very few cases
                w = torch.nn.Parameter(w, requires_grad=False)
            utils.set_attr_raw(self.model, k, w)

        self.backup = {}

        online_mode = dynamic_args.get('online_lora', False)

        # Patch
        for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
            try:
                parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
                assert isinstance(weight, torch.nn.Parameter)
            except:
                raise ValueError(f"Wrong LoRA Key: {key}")

            if key not in self.backup:
                self.backup[key] = weight.to(device=offload_device)

            if online_mode:
                if not hasattr(parent_layer, 'forge_online_loras'):
                    parent_layer.forge_online_loras = {}

                parent_layer.forge_online_loras[child_key] = current_patches
                self.online_backup.append(parent_layer)
                continue

            bnb_layer = None

            if operations.bnb_avaliable:
                if hasattr(weight, 'bnb_quantized'):
                    bnb_layer = parent_layer
                    if weight.bnb_quantized:
                        weight_original_device = weight.device

                        if target_device is not None:
                            assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                            weight = weight.to(target_device)
                        else:
                            weight = weight.cuda()

                        from backend.operations_bnb import functional_dequantize_4bit
                        weight = functional_dequantize_4bit(weight)

                        if target_device is None:
                            weight = weight.to(device=weight_original_device)
                    else:
                        weight = weight.data

            if target_device is not None:
                try:
                    weight = weight.to(device=target_device)
                except:
                    print('Moving layer weight failed. Retrying by offloading models.')
                    self.model.to(device=offload_device)
                    memory_management.soft_empty_cache()
                    weight = weight.to(device=target_device)

            gguf_cls, gguf_type, gguf_real_shape = None, None, None

            if hasattr(weight, 'is_gguf'):
                from backend.operations_gguf import dequantize_tensor
                gguf_cls = weight.gguf_cls
                gguf_type = weight.gguf_type
                gguf_real_shape = weight.gguf_real_shape
                weight = dequantize_tensor(weight)

            try:
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
            except:
                print('Patching LoRA weights failed. Retrying by offloading models.')
                self.model.to(device=offload_device)
                memory_management.soft_empty_cache()
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)

            if bnb_layer is not None:
                bnb_layer.reload_weight(weight)
                continue

            if gguf_cls is not None:
                from backend.operations_gguf import ParameterGGUF
                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
                utils.set_attr_raw(self.model, key, ParameterGGUF.make(
                    data=weight,
                    gguf_type=gguf_type,
                    gguf_cls=gguf_cls,
                    gguf_real_shape=gguf_real_shape
                ))
                continue

            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))

        # Time
        moving_time = time.perf_counter() - execution_start_time

        if moving_time > 0.1:
            print(f'LoRA patching has taken {moving_time:.2f} seconds')

        return
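# Hedged usage sketch (illustrative only; `unet` and the LoRA path are hypothetical,
# and utils.load_torch_file is assumed to exist as in ComfyUI-derived backends):
#
#     state_dict = utils.load_torch_file('example_lora.safetensors')
#     key_map = model_lora_keys_unet(unet)
#     patches, _ = load_lora(state_dict, key_map)
#     loader = LoraLoader(unet)
#     loader.add_patches(patches, strength_patch=0.8)
#     loader.refresh(target_device=torch.device('cuda'))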