Spaces:
Runtime error
Runtime error
File size: 4,515 Bytes
6831a54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import gguf
import torch
import os
import json
import safetensors.torch
import backend.misc.checkpoint_pickle
from backend.operations_gguf import ParameterGGUF
def read_arbitrary_config(directory):
    """Parse and return the ``config.json`` stored in *directory*.

    Args:
        directory: Folder expected to contain a ``config.json`` file.

    Returns:
        The object produced by ``json.load`` for that file.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in *directory*.
    """
    config_path = os.path.join(directory, 'config.json')

    if os.path.exists(config_path):
        with open(config_path, 'rt', encoding='utf-8') as handle:
            return json.load(handle)

    raise FileNotFoundError(f"No config.json file found in the directory: {directory}")
def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file and return its state dict.

    The loader is picked from the (case-insensitive) file extension:

    * ``.safetensors`` -- read with ``safetensors.torch.load_file``.
    * ``.gguf`` -- every tensor exposed by ``gguf.GGUFReader`` is wrapped in
      a ``ParameterGGUF`` and stored under its name.
    * anything else -- read with ``torch.load``.

    Args:
        ckpt: Path of the checkpoint file.
        safe_load: Only used on the ``torch.load`` path. When true, the file
            is loaded with ``weights_only=True``; if this torch build does
            not expose that option a warning is printed and the file is
            loaded without it.
        device: ``torch.device`` or device string (e.g. ``"cpu"``) to load
            onto. Defaults to CPU. Not used by the GGUF branch.

    Returns:
        The loaded mapping. On the ``torch.load`` path, if the loaded object
        contains a ``"state_dict"`` entry, that entry is returned instead of
        the whole object.
    """
    if device is None:
        device = torch.device("cpu")
    elif isinstance(device, str):
        # The safetensors branch reads ``device.type``, which a plain string
        # does not have; normalise so string devices work on every branch.
        device = torch.device(device)

    lowered = ckpt.lower()

    if lowered.endswith(".safetensors"):
        # NOTE(review): only the device *type* is forwarded, so an explicit
        # index (e.g. "cuda:1") is not passed on -- confirm this is intended.
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    elif lowered.endswith(".gguf"):
        reader = gguf.GGUFReader(ckpt)
        sd = {str(tensor.name): ParameterGGUF(tensor) for tensor in reader.tensors}
    else:
        if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
            print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False

        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            # SECURITY: this path unpickles the file through the project's
            # own pickle module (not visible here). Unpickling can run code
            # embedded in the file, so only use it for trusted checkpoints.
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)

        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")

        # Some checkpoints nest the weights under "state_dict"; unwrap them.
        sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    return sd
def set_attr(obj, attr, value):
    """Replace a nested attribute with a frozen ``torch.nn.Parameter``.

    Args:
        obj: Root object the dotted path starts from.
        attr: Dotted attribute path, e.g. ``"block.0.weight"``.
        value: Tensor wrapped in ``torch.nn.Parameter(requires_grad=False)``
            and assigned to the final path component.
    """
    *parents, leaf = attr.split(".")

    owner = obj
    for part in parents:
        owner = getattr(owner, part)

    frozen = torch.nn.Parameter(value, requires_grad=False)
    setattr(owner, leaf, frozen)
def set_attr_raw(obj, attr, value):
    """Assign *value* as-is to a nested attribute.

    Args:
        obj: Root object the dotted path starts from.
        attr: Dotted attribute path; every component but the last is
            resolved with ``getattr``.
        value: Object stored on the final path component without wrapping.
    """
    *parents, leaf = attr.split(".")

    owner = obj
    for part in parents:
        owner = getattr(owner, part)

    setattr(owner, leaf, value)
def copy_to_param(obj, attr, value):
    """Copy *value* into an existing nested attribute in place.

    The object currently stored at the dotted path is kept; only its
    ``.data`` is overwritten via ``copy_``.

    Args:
        obj: Root object the dotted path starts from.
        attr: Dotted attribute path to the existing target.
        value: Source passed to ``target.data.copy_``.
    """
    *parents, leaf = attr.split(".")

    owner = obj
    for part in parents:
        owner = getattr(owner, part)

    target = getattr(owner, leaf)
    target.data.copy_(value)
def get_attr(obj, attr):
    """Resolve a dotted attribute path starting at *obj*.

    Args:
        obj: Root object.
        attr: Dotted attribute path, e.g. ``"a.b.c"``.

    Returns:
        The object reached after applying ``getattr`` for every component.
    """
    current = obj
    for segment in attr.split("."):
        current = getattr(current, segment)
    return current
def get_attr_with_parent(obj, attr):
    """Resolve a dotted attribute path and also report where it lives.

    Args:
        obj: Root object.
        attr: Dotted attribute path, e.g. ``"a.b.c"``.

    Returns:
        A ``(parent, name, value)`` tuple: the object holding the final
        attribute, the final path component, and the attribute's value.
        For a single-component path the parent is *obj* itself.
    """
    *ancestors, leaf = attr.split(".")

    owner = obj
    for part in ancestors:
        owner = getattr(owner, part)

    return owner, leaf, getattr(owner, leaf)
def calculate_parameters(sd, prefix=""):
    """Count the elements of all state-dict entries under *prefix*.

    Args:
        sd: Mapping from string keys to objects exposing ``nelement()``.
        prefix: Only keys starting with this string are counted; the default
            empty string matches every key.

    Returns:
        Sum of ``nelement()`` over the matching entries (0 if none match).
    """
    return sum(sd[key].nelement() for key in sd.keys() if key.startswith(prefix))
def tensor2parameter(x):
    """Return *x* as a ``torch.nn.Parameter``.

    An existing ``torch.nn.Parameter`` is returned unchanged; anything else
    is wrapped in a new parameter with ``requires_grad=False``.
    """
    already_parameter = isinstance(x, torch.nn.Parameter)
    return x if already_parameter else torch.nn.Parameter(x, requires_grad=False)
def fp16_fix(x):
# An interesting trick to avoid fp16 overflow
# Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
# Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180
if x.dtype in [torch.float16]:
return x.clip(-32768.0, 32768.0)
return x
def nested_compute_size(obj):
    """Return the total byte size of all tensors nested inside *obj*.

    Dicts (by value), lists and tuples are walked recursively. Each
    ``torch.Tensor`` contributes ``nelement() * element_size()``; every other
    object contributes 0.
    """
    if isinstance(obj, dict):
        return sum(nested_compute_size(obj[key]) for key in obj)

    if isinstance(obj, (list, tuple)):
        return sum(nested_compute_size(item) for item in obj)

    if isinstance(obj, torch.Tensor):
        return obj.nelement() * obj.element_size()

    return 0
def nested_move_to_device(obj, device):
    """Recursively call ``.to(device)`` on every tensor nested inside *obj*.

    Dicts and lists are updated in place and returned; tuples are rebuilt as
    new plain tuples; tensors are replaced by the result of ``.to(device)``;
    any other object is returned untouched.

    Args:
        obj: Tensor, container of tensors, or arbitrary object.
        device: Argument forwarded to ``torch.Tensor.to``.

    Returns:
        The moved structure (the same dict/list object for those types).
    """
    if isinstance(obj, dict):
        # Only values of existing keys are reassigned, so mutating while
        # iterating is safe here.
        for key, item in obj.items():
            obj[key] = nested_move_to_device(item, device)
        return obj

    if isinstance(obj, list):
        for index, item in enumerate(obj):
            obj[index] = nested_move_to_device(item, device)
        return obj

    if isinstance(obj, tuple):
        return tuple(nested_move_to_device(item, device) for item in obj)

    if isinstance(obj, torch.Tensor):
        return obj.to(device)

    return obj
def get_state_dict_after_quant(model, prefix=''):
    """Return a cloned, prefixed copy of *model*'s state dict.

    Before the state dict is read, every sub-module whose ``weight`` has a
    falsy ``bnb_quantized`` attribute is moved to CUDA and back to the
    device its weight was on.

    Args:
        model: Object exposing ``modules()`` and ``state_dict()``.
        prefix: String prepended to every state-dict key.

    Returns:
        A new dict mapping ``prefix + key`` to a ``clone()`` of each entry.
    """
    for module in model.modules():
        if not hasattr(module, 'weight'):
            continue
        if not hasattr(module.weight, 'bnb_quantized'):
            continue
        if module.weight.bnb_quantized:
            continue

        # NOTE(review): presumably the CUDA round trip is what triggers the
        # pending quantization of the weight -- confirm against the
        # quantization backend.
        home_device = module.weight.device
        module.cuda()
        module.to(home_device)

    return {(prefix + key): tensor.clone() for key, tensor in model.state_dict().items()}
|