import gguf
import torch
import os
import json
import safetensors.torch
import backend.misc.checkpoint_pickle

from backend.operations_gguf import ParameterGGUF


def read_arbitrary_config(directory):
    """Load and return the parsed ``config.json`` found in *directory*.

    Raises:
        FileNotFoundError: if ``directory`` contains no ``config.json``.
    """
    config_path = os.path.join(directory, 'config.json')

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"No config.json file found in the directory: {directory}")

    with open(config_path, 'rt', encoding='utf-8') as file:
        return json.load(file)


def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file and return its state dict.

    The loader is selected from the (case-insensitive) file extension:
    ``.safetensors`` files go through ``safetensors.torch.load_file``,
    ``.gguf`` files are read with ``gguf.GGUFReader`` and every tensor is
    wrapped in ``ParameterGGUF``, and anything else is handed to
    ``torch.load``.

    Args:
        ckpt: Path of the checkpoint file.
        safe_load: For ``torch.load`` checkpoints only, request
            ``weights_only=True``. Falls back to the unsafe path (with a
            printed warning) when this torch build has no such parameter.
        device: Target device; ``None`` means CPU. Anything that is not
            already a ``torch.device`` is passed through ``torch.device``.

    Returns:
        The state dict. For ``torch.load`` checkpoints the nested
        ``"state_dict"`` entry is returned when present, otherwise the
        loaded object itself.
    """
    if device is None:
        device = torch.device("cpu")
    elif not isinstance(device, torch.device):
        # Accept "cpu" / "cuda:0" style specs; ``device.type`` below needs a torch.device.
        device = torch.device(device)

    lowered = ckpt.lower()

    if lowered.endswith(".safetensors"):
        # Only the device *type* is forwarded, so a device index is not passed on.
        return safetensors.torch.load_file(ckpt, device=device.type)

    if lowered.endswith(".gguf"):
        reader = gguf.GGUFReader(ckpt)
        return {str(tensor.name): ParameterGGUF(tensor) for tensor in reader.tensors}

    if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
        print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
        safe_load = False

    if safe_load:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
    else:
        # NOTE(review): this path unpickles the file through a project-specific pickle
        # module; what that module restricts is not visible here. Treat checkpoints
        # from untrusted sources with care and prefer safe_load=True.
        pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    return pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd


def _walk_to_parent(obj, attr):
    """Follow dotted path *attr* from *obj*; return ``(parent, final_name)``."""
    *parents, leaf = attr.split(".")
    for name in parents:
        obj = getattr(obj, name)
    return obj, leaf


def set_attr(obj, attr, value):
    """Set dotted attribute *attr* on *obj* to *value* wrapped in a frozen ``Parameter``."""
    parent, leaf = _walk_to_parent(obj, attr)
    setattr(parent, leaf, torch.nn.Parameter(value, requires_grad=False))


def set_attr_raw(obj, attr, value):
    """Set dotted attribute *attr* on *obj* to *value* as-is (no ``Parameter`` wrapping)."""
    parent, leaf = _walk_to_parent(obj, attr)
    setattr(parent, leaf, value)


def copy_to_param(obj, attr, value):
    """Copy *value* in place into the ``.data`` of the existing dotted attribute *attr*."""
    parent, leaf = _walk_to_parent(obj, attr)
    existing = getattr(parent, leaf)
    existing.data.copy_(value)


def get_attr(obj, attr):
    """Return the object reached by following dotted path *attr* from *obj*."""
    for name in attr.split("."):
        obj = getattr(obj, name)
    return obj


def get_attr_with_parent(obj, attr):
    """Resolve dotted path *attr* and return ``(parent, last_name, value)``.

    ``parent`` is the object that directly holds the final attribute,
    ``last_name`` is the final path component and ``value`` is the resolved
    attribute itself.
    """
    parent = obj
    name = None
    for name in attr.split("."):
        parent = obj
        obj = getattr(obj, name)
    return parent, name, obj


def calculate_parameters(sd, prefix=""):
    """Sum ``nelement()`` over every entry of *sd* whose key starts with *prefix*."""
    return sum(v.nelement() for k, v in sd.items() if k.startswith(prefix))


def tensor2parameter(x):
    """Return *x* unchanged if it is already a ``Parameter``, else wrap it (no grad)."""
    if isinstance(x, torch.nn.Parameter):
        return x
    return torch.nn.Parameter(x, requires_grad=False)


def fp16_fix(x):
    """Clip float16 tensors to ``[-32768, 32768]``; return other dtypes unchanged."""
    # An interesting trick to avoid fp16 overflow
    # Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
    # Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180

    if x.dtype == torch.float16:
        return x.clip(-32768.0, 32768.0)
    return x


def nested_compute_size(obj):
    """Return the total byte size of all tensors nested inside *obj*.

    Dicts, lists and tuples are traversed recursively; tensors contribute
    ``nelement() * element_size()``; any other object counts as zero.
    """
    if isinstance(obj, dict):
        return sum(nested_compute_size(value) for value in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(nested_compute_size(item) for item in obj)
    if isinstance(obj, torch.Tensor):
        return obj.nelement() * obj.element_size()
    return 0


def nested_move_to_device(obj, device):
    """Move every tensor nested inside *obj* to *device* and return the result.

    Dicts and lists are updated in place (and returned); tuples are rebuilt
    as new tuples; tensors are returned via ``.to(device)``; anything else is
    returned untouched.
    """
    if isinstance(obj, dict):
        # Only existing keys are reassigned, so mutating while iterating is safe.
        for key in obj:
            obj[key] = nested_move_to_device(obj[key], device)
    elif isinstance(obj, list):
        for index, item in enumerate(obj):
            obj[index] = nested_move_to_device(item, device)
    elif isinstance(obj, tuple):
        obj = tuple(nested_move_to_device(item, device) for item in obj)
    elif isinstance(obj, torch.Tensor):
        return obj.to(device)
    return obj


def get_state_dict_after_quant(model, prefix=''):
    """Return a cloned, *prefix*-keyed state dict of *model*.

    Before the state dict is taken, every submodule whose ``weight`` carries
    a falsy ``bnb_quantized`` flag is moved to CUDA and back to its original
    device.
    """
    for m in model.modules():
        if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized:
            # NOTE(review): presumably the CUDA round trip is what makes bitsandbytes
            # quantize the weight - confirm against the bnb integration.
            original_device = m.weight.device
            m.cuda()
            m.to(original_device)

    return {(prefix + k): v.clone() for k, v in model.state_dict().items()}