import io
import os
import mmap
import torch
import json
import hashlib
import safetensors
import safetensors.torch

from modules import sd_models

# PyTorch 1.13 and later renamed torch.storage._UntypedStorage to UntypedStorage.
UntypedStorage = torch.storage.UntypedStorage if hasattr(torch.storage, 'UntypedStorage') else torch.storage._UntypedStorage


def read_metadata(filename):
    """Reads the JSON metadata from a .safetensors file."""
    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            header = m.read(8)  # 8-byte little-endian length of the JSON header
            n = int.from_bytes(header, "little")
            metadata_bytes = m.read(n)
            metadata = json.loads(metadata_bytes)

    return metadata.get("__metadata__", {})
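

# Illustrative sketch, not part of the original module: the on-disk layout
# read_metadata depends on is an 8-byte little-endian length prefix followed
# by that many bytes of JSON header; user metadata lives under the
# "__metadata__" key. The tensor name and values below are hypothetical.
def _example_build_minimal_header():
    header = {
        "__metadata__": {"ss_learning_rate": "1e-4"},
        "weight": {"dtype": "F16", "shape": [2, 2], "data_offsets": [0, 8]},
    }
    blob = json.dumps(header).encode("utf8")
    prefix = len(blob).to_bytes(8, "little")
    # prefix + blob + 8 bytes of tensor data would form a complete file body
    return prefix + blob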


def load_file(filename, device):
    """Loads a .safetensors file without the memory mapping that locks the
    model file. Works around safetensors issue:
    https://github.com/huggingface/safetensors/issues/164

    Note: `device` is currently unused; tensors are returned on CPU.
    """
    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            header = m.read(8)
            n = int.from_bytes(header, "little")
            metadata_bytes = m.read(n)
            metadata = json.loads(metadata_bytes)

    size = os.stat(filename).st_size
    storage = UntypedStorage.from_file(filename, False, size)
    offset = n + 8  # tensor data begins after the length prefix and JSON header
    md = metadata.get("__metadata__", {})
    # create_tensor clones each tensor out of the file-backed storage, so no
    # open handle to the model file remains once loading finishes.
    return {name: create_tensor(storage, info, offset) for name, info in metadata.items() if name != "__metadata__"}, md
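

# Illustrative usage sketch, not part of the original module; the path below
# is hypothetical.
def _example_load(filename="models/Lora/example.safetensors"):
    tensors, metadata = load_file(filename, "cpu")
    for name, tensor in tensors.items():
        print(name, tuple(tensor.shape), tensor.dtype)
    return metadata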


def hash_file(filename):
    """Hashes a .safetensors file using the new hashing method.
    Only hashes the weights of the model."""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            header = m.read(8)
            n = int.from_bytes(header, "little")

    with open(filename, mode="rb") as file_obj:
        offset = n + 8  # skip the length prefix and the JSON header
        file_obj.seek(offset)
        for chunk in iter(lambda: file_obj.read(blksize), b""):
            hash_sha256.update(chunk)

    return hash_sha256.hexdigest()
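

# Illustrative sketch, not part of the original module: a compact equivalent
# of hash_file that reads the tensor region in one call instead of chunking;
# it produces the same digest for files that fit comfortably in memory.
def _example_hash_tensor_region(filename):
    with open(filename, "rb") as f:
        n = int.from_bytes(f.read(8), "little")
        f.seek(8 + n)
        return hashlib.sha256(f.read()).hexdigest()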


def legacy_hash_file(filename):
    """Hashes a model file using the legacy `sd_models.model_hash()` method."""
    hash_sha256 = hashlib.sha256()
    metadata = read_metadata(filename)

    # For compatibility with legacy models: this replicates the behavior of
    # sd_models.model_hash as if there were no user-specified metadata in the
    # .safetensors file. That leaves the training parameters, which are
    # immutable. It is important that the hash does not include the embedded
    # user metadata, as otherwise the hash could change every time the user
    # updates the name/description/etc. The new hashing method fixes this
    # problem by hashing only the region of the file containing the tensors.
    if any(not k.startswith("ss_") for k in metadata):
        # Strip the user metadata, re-serialize the file as if it were freshly
        # created from sd-scripts, and hash that with model_hash's behavior.
        tensors, metadata = load_file(filename, "cpu")
        metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
        model_bytes = safetensors.torch.save(tensors, metadata)

        hash_sha256.update(model_bytes[0x100000:0x110000])
        return hash_sha256.hexdigest()[0:8]
    else:
        # This works with model_hash because the user metadata system had not
        # been implemented yet when the legacy hashing method was in use.
        return sd_models.model_hash(filename)
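

# Illustrative sketch, not part of the original module: to the best of
# recollection (details may differ), sd_models.model_hash digests the 0x10000
# bytes starting at offset 0x100000 and truncates to 8 hex digits, which is
# why legacy_hash_file slices model_bytes[0x100000:0x110000] above.
def _example_legacy_region_hash(filename):
    with open(filename, "rb") as f:
        f.seek(0x100000)
        return hashlib.sha256(f.read(0x10000)).hexdigest()[0:8]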


# Maps safetensors dtype strings to torch dtypes. The commented-out unsigned
# types have no torch equivalent here.
DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    # "U64": torch.uint64,
    "I32": torch.int32,
    # "U32": torch.uint32,
    "I16": torch.int16,
    # "U16": torch.uint16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}


def create_tensor(storage, info, offset):
    """Creates a tensor without holding on to an open handle to the parent
    model file."""
    dtype = DTYPES[info["dtype"]]
    shape = info["shape"]
    start, stop = info["data_offsets"]
    # Slice the raw bytes out of the storage, reinterpret them as the target
    # dtype, and clone so the result no longer references the backing storage.
    return torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8).view(dtype=dtype).reshape(shape).clone().detach()
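

# Illustrative sketch, not part of the original module: reconstructing a small
# tensor from an in-memory storage with offset=0. Assumes torch >= 2.0 for
# Tensor.untyped_storage(); the module above also supports older releases.
def _example_create_tensor():
    base = torch.zeros(4, dtype=torch.float16)  # 8 bytes of zeros
    info = {"dtype": "F16", "shape": [2, 2], "data_offsets": [0, 8]}
    return create_tensor(base.untyped_storage(), info, offset=0)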