|
import hashlib |
|
from io import BytesIO |
|
from typing import Optional |
|
|
|
import safetensors.torch |
|
import torch |
|
|
|
|
|
def model_hash(filename):
    """Old model hash used by stable-diffusion-webui.

    Hashes (SHA-256) only the 0x10000 bytes that start at offset 0x100000 and
    returns the first 8 hex characters of the digest. A file shorter than the
    offset therefore hashes the empty byte string.

    Args:
        filename: path of the model file to hash.

    Returns:
        The 8-character hex hash, or a sentinel string:
        "NOFILE" when the path does not exist, or "IsADirectory" when opening
        or reading raises IsADirectoryError or PermissionError.
    """
    window_start, window_size = 0x100000, 0x10000
    try:
        with open(filename, "rb") as fh:
            fh.seek(window_start)
            window = fh.read(window_size)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        # NOTE(review): PermissionError is presumably what opening a directory
        # raises on Windows; this also maps genuine permission failures to the
        # "IsADirectory" sentinel.
        return "IsADirectory"
    return hashlib.sha256(window).hexdigest()[:8]
|
|
|
|
|
def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui.

    Streams the whole file through SHA-256 in 1 MiB reads, so memory use stays
    bounded regardless of file size.

    Args:
        filename: path of the model file to hash.

    Returns:
        The full hex SHA-256 digest, or a sentinel string:
        "NOFILE" when the path does not exist, or "IsADirectory" when opening
        or reading raises IsADirectoryError or PermissionError.
    """
    digest = hashlib.sha256()
    read_size = 1024 * 1024
    try:
        with open(filename, "rb") as fh:
            while True:
                piece = fh.read(read_size)
                if piece == b"":
                    break
                digest.update(piece)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        # NOTE(review): PermissionError is presumably raised when opening a
        # directory on Windows; it shares the "IsADirectory" sentinel.
        return "IsADirectory"
    return digest.hexdigest()
|
|
|
|
|
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    Args:
        b: a seekable binary file-like object; its position is moved.

    Returns:
        The first 8 hex characters of the SHA-256 of the 0x10000 bytes starting
        at offset 0x100000 (fewer bytes, possibly none, if the stream is shorter).
    """
    window_start, window_size = 0x100000, 0x10000
    b.seek(window_start)
    return hashlib.sha256(b.read(window_size)).hexdigest()[:8]
|
|
|
|
|
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    The first 8 bytes of the stream are read as a little-endian integer n, the
    length of the header that follows. Only the bytes after that header
    (offset n + 8 onward) are hashed, so the result does not depend on header
    contents.

    Args:
        b: a seekable binary file-like object; its position is moved.

    Returns:
        The full hex SHA-256 digest of everything after the header.
    """
    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")
    # Skip the 8-byte length prefix plus the header itself.
    b.seek(8 + header_len)

    digest = hashlib.sha256()
    read_size = 1024 * 1024
    while True:
        piece = b.read(read_size)
        if piece == b"":
            break
        digest.update(piece)
    return digest.hexdigest()
|
|
|
|
|
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Only metadata entries whose key starts with "ss_" are kept before
    serializing. This way the hashes do not depend on any other metadata
    entries that may be added to the file.

    Args:
        tensors: tensor mapping passed straight to ``safetensors.torch.save``.
        metadata: mapping of metadata keys to values, or None for no metadata.

    Returns:
        A ``(model_hash, legacy_hash)`` tuple, as produced by
        ``addnet_hash_safetensors`` and ``addnet_hash_legacy`` on the
        serialized bytes.
    """
    if metadata is not None:
        metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    # Avoid shadowing the builtin ``bytes`` and the module-level ``model_hash``.
    serialized = safetensors.torch.save(tensors, metadata)
    buffer = BytesIO(serialized)

    # Each hash helper seeks to its own start offset, so sharing the buffer is safe.
    new_hash = addnet_hash_safetensors(buffer)
    legacy_hash = addnet_hash_legacy(buffer)
    return new_hash, legacy_hash
|
|
|
|
|
def dtype_to_str(dtype: torch.dtype) -> str:
    """Return the bare name of a dtype, e.g. ``torch.float16`` -> ``"float16"``.

    This is the text after the last "." in ``str(dtype)``, or the whole string
    when it contains no ".".
    """
    _, _, bare_name = str(dtype).rpartition(".")
    return bare_name
|
|
|
|
|
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """
    Convert a string to a torch.dtype

    Args:
        s: string representation of the dtype (case-sensitive alias)
        default_dtype: value returned unchanged when s is None

    Returns:
        torch.dtype: the corresponding torch.dtype, or default_dtype
        (which may be None) when s is None

    Raises:
        ValueError: if the dtype is not supported

    Examples:
        >>> str_to_dtype("float32")
        torch.float32
        >>> str_to_dtype("fp32")
        torch.float32
        >>> str_to_dtype("float16")
        torch.float16
        >>> str_to_dtype("fp16")
        torch.float16
        >>> str_to_dtype("bfloat16")
        torch.bfloat16
        >>> str_to_dtype("bf16")
        torch.bfloat16
        >>> str_to_dtype("fp8")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fn")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fnuz")
        torch.float8_e4m3fnuz
        >>> str_to_dtype("fp8_e5m2")
        torch.float8_e5m2
        >>> str_to_dtype("fp8_e5m2fnuz")
        torch.float8_e5m2fnuz
    """
    if s is None:
        return default_dtype

    # (accepted aliases, name of the torch attribute to return). The attribute
    # is looked up only after an alias matches, so a torch build lacking one of
    # the float8 types fails only when that type is actually requested.
    alias_table = (
        (("bf16", "bfloat16"), "bfloat16"),
        (("fp16", "float16"), "float16"),
        (("fp32", "float32", "float"), "float32"),
        (("fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"), "float8_e4m3fn"),
        (("fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"), "float8_e4m3fnuz"),
        (("fp8_e5m2", "e5m2", "float8_e5m2"), "float8_e5m2"),
        (("fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"), "float8_e5m2fnuz"),
        # Plain "fp8" defaults to the e4m3fn variant.
        (("fp8", "float8"), "float8_e4m3fn"),
    )
    for aliases, attr_name in alias_table:
        if s in aliases:
            return getattr(torch, attr_name)

    raise ValueError(f"Unsupported dtype: {s}")
|
|