# Source: CM2000112/external/llite/library/sai_model_spec.py
# (Hugging Face Hub upload by jayparmr: "Upload folder using huggingface_hub", commit ea5c647, 9.49 kB)
# based on https://github.com/Stability-AI/ModelSpec
import datetime
import hashlib
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors
r"""
# Metadata Example
metadata = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
"modelspec.implementation": "sgm",
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
# === Should ===
"modelspec.author": "Example Corp", # Your name or company name
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
# === Can ===
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
}
"""
BASE_METADATA = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === Should ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
# === Can ===
"modelspec.license": None,
"modelspec.tags": None,
"modelspec.merged_from": None,
"modelspec.prediction_type": None,
"modelspec.timestep_range": None,
"modelspec.encoder_layer": None,
}
# 別に使うやつだけ定義
MODELSPEC_TITLE = "modelspec.title"
ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_DIFFUSERS = "diffusers"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
def load_bytes_in_safetensors(tensors):
    """Serialize ``tensors`` with safetensors and return only the tensor-data section.

    A safetensors blob starts with an 8-byte little-endian length ``n``. Next come
    ``n`` bytes of JSON header, and then the raw tensor bytes. This returns everything
    after the header.

    Args:
        tensors: Mapping of name -> tensor accepted by ``safetensors.torch.save``
            (presumably ``torch.Tensor`` values -- confirm against callers).

    Returns:
        The raw tensor-data bytes, without the length prefix and JSON header.
    """
    # The file-level `import safetensors` is not guaranteed to load the
    # `safetensors.torch` submodule. Import it explicitly rather than relying on
    # some other module having imported it first.
    from safetensors.torch import save

    data = save(tensors)
    header_len = int.from_bytes(data[:8], "little")
    # Slice past the length prefix and the JSON header; no BytesIO copy needed.
    return data[8 + header_len:]
def precalculate_safetensors_hashes(state_dict):
    """Return a "0x"-prefixed SHA-256 hex digest over the state dict's tensor bytes.

    Each tensor is serialized on its own via ``load_bytes_in_safetensors``, and the
    resulting data bytes are fed into a single running hash. This keeps only one
    tensor's serialized bytes in memory at a time.

    NOTE(review): tensors are hashed in ``state_dict`` iteration order. Whether this
    equals a hash of a saved .safetensors file's data section is not verified here.
    """
    digest = hashlib.sha256()
    for t in state_dict.values():
        digest.update(load_bytes_in_safetensors({"tensor": t}))
    return "0x" + digest.hexdigest()
def update_hash_sha256(metadata: dict, state_dict: dict):
    """Placeholder for storing a SHA-256 hash of ``state_dict`` in ``metadata``.

    Not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def _architecture_id(v2: bool, v_parameterization: bool, sdxl: bool, lora: bool, textual_inversion: bool) -> str:
    """Return the ModelSpec architecture ID, suffixed with the adapter kind for LoRA/TI."""
    if sdxl:
        arch = ARCH_SD_XL_V1_BASE
    elif v2:
        arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
    else:
        arch = ARCH_SD_V1
    if lora:
        arch += f"/{ADAPTER_LORA}"
    elif textual_inversion:
        arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
    return arch


def _resolution_str(reso, sdxl: bool, v2: bool, v_parameterization: bool) -> str:
    """Normalize ``reso`` (None, int, 1- or 2-sequence, or comma string) to "WxH"."""
    if reso is None:
        # The resolution is defined in the dataset config, so fall back to the model default.
        if sdxl:
            reso = 1024
        elif v2 and v_parameterization:
            reso = 768
        else:
            reso = 512
    elif isinstance(reso, str):
        # comma-separated string -> tuple of ints, e.g. "1024,768" -> (1024, 768)
        reso = tuple(map(int, reso.split(",")))
    # Check for int before calling len(): an int has no length.
    if isinstance(reso, int):
        reso = (reso, reso)
    elif len(reso) == 1:
        reso = (reso[0], reso[0])
    return f"{reso[0]}x{reso[1]}"


def _timestep_range_str(timesteps) -> str:
    """Normalize ``timesteps`` (int, str, or 1- or 2-sequence) to "min,max"."""
    if isinstance(timesteps, int):
        timesteps = (timesteps, timesteps)
    elif isinstance(timesteps, str):
        # "200" -> ("200",), which is duplicated below; "200,800" -> ("200", "800")
        timesteps = tuple(timesteps.split(","))
    if len(timesteps) == 1:
        timesteps = (timesteps[0], timesteps[0])
    return f"{timesteps[0]},{timesteps[1]}"


def build_metadata(
    state_dict: Optional[dict],
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    timestamp: float,
    title: Optional[str] = None,
    reso: Optional[Union[int, Tuple[int, int]]] = None,
    is_stable_diffusion_ckpt: Optional[bool] = None,
    author: Optional[str] = None,
    description: Optional[str] = None,
    license: Optional[str] = None,
    tags: Optional[str] = None,
    merged_from: Optional[str] = None,
    timesteps: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
):
    """Build a ModelSpec (``modelspec.*``) metadata dict for a saved model.

    Args:
        state_dict: Model weights. Currently unused, because hash calculation is
            disabled (see the TODO in the body).
        v2: True for a Stable Diffusion v2 model.
        v_parameterization: True for a v-prediction model. It selects the 768-v
            architecture for v2 and sets the prediction type to ``PRED_TYPE_V``.
        sdxl: True for an SDXL model; takes precedence over ``v2``.
        lora: True when the file is a LoRA adapter.
        textual_inversion: True when the file is a Textual Inversion embedding.
        timestamp: POSIX timestamp, used for the default title and for ``modelspec.date``.
        title: Human-readable title. Defaults to "LoRA@<timestamp>",
            "TextualInversion@<timestamp>" or "Checkpoint@<timestamp>".
        reso: Resolution as an int, a 1- or 2-element sequence, or a comma-separated
            string such as "1024" or "1024,768". Defaults to 1024 (SDXL),
            768 (v2 with v-parameterization) or 512.
        is_stable_diffusion_ckpt: For non-adapter models, selects
            ``IMPL_STABILITY_AI`` rather than ``IMPL_DIFFUSERS``. Defaults to True for
            non-adapter models.
        author, description, license, tags, merged_from: Optional fields. Each is
            stored when given and omitted when ``None``.
        timesteps: Timestep range as an int/str (used for both ends), a 1- or
            2-element sequence, or a comma-separated string. Omitted when ``None``.
        clip_skip: Stored as ``modelspec.encoder_layer`` when given.

    Returns:
        A new dict. Its keys follow ``BASE_METADATA`` order, and keys for omitted
        optional fields are removed.
    """
    # Copy the template so the shared BASE_METADATA dict is never mutated.
    metadata = dict(BASE_METADATA)

    # TODO: implement once there is a way to compute the correct hash without heavy memory use.
    # if state_dict is not None:
    #     hash = precalculate_safetensors_hashes(state_dict)
    #     metadata["modelspec.hash_sha256"] = hash

    metadata["modelspec.architecture"] = _architecture_id(v2, v_parameterization, sdxl, lora, textual_inversion)

    if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
        is_stable_diffusion_ckpt = True  # non-adapter models default to a Stable Diffusion ckpt
    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
        # Stable Diffusion ckpt, TI, SDXL LoRA
        metadata["modelspec.implementation"] = IMPL_STABILITY_AI
    else:
        # v1/v2 LoRA or Diffusers
        metadata["modelspec.implementation"] = IMPL_DIFFUSERS

    if title is None:
        if lora:
            title = "LoRA"
        elif textual_inversion:
            title = "TextualInversion"
        else:
            title = "Checkpoint"
        title += f"@{timestamp}"
    metadata[MODELSPEC_TITLE] = title

    # Drop sub-second precision, then format as ISO-8601 (naive local time).
    metadata["modelspec.date"] = datetime.datetime.fromtimestamp(int(timestamp)).isoformat()

    metadata["modelspec.resolution"] = _resolution_str(reso, sdxl, v2, v_parameterization)

    metadata["modelspec.prediction_type"] = PRED_TYPE_V if v_parameterization else PRED_TYPE_EPSILON

    # Optional keys: store the value when given, otherwise remove the placeholder key.
    # Each key already exists in the copied template, so assignment keeps its position.
    optional_fields = {
        "modelspec.author": author,
        "modelspec.description": description,
        "modelspec.merged_from": merged_from,
        "modelspec.license": license,
        "modelspec.tags": tags,
        "modelspec.timestep_range": None if timesteps is None else _timestep_range_str(timesteps),
        "modelspec.encoder_layer": None if clip_skip is None else f"{clip_skip}",
    }
    for key, value in optional_fields.items():
        if value is None:
            del metadata[key]
        else:
            metadata[key] = value

    # Every remaining key should now hold a value. Report a gap rather than failing.
    if any(v is None for v in metadata.values()):
        print(f"Internal error: some metadata values are None: {metadata}")
    return metadata
# region utils
def get_title(metadata: dict) -> Optional[str]:
    """Return the ``modelspec.title`` entry of ``metadata``, or ``None`` when it is absent."""
    return metadata.get(MODELSPEC_TITLE)
def load_metadata_from_safetensors(model: str) -> dict:
    """Return the metadata stored in a ``.safetensors`` file's header.

    Returns an empty dict in two cases: when ``model`` does not end with
    ".safetensors", or when ``safe_open(...).metadata()`` returns ``None``.
    """
    if model.endswith(".safetensors"):
        with safetensors.safe_open(model, framework="pt") as handle:
            found = handle.metadata()
        return {} if found is None else found
    return {}
def build_merged_from(models: List[str]) -> str:
    """Return the titles of ``models``, joined with ", ".

    Each title is the file's ``modelspec.title`` metadata when present. Otherwise it
    is the file name without its directory and extension.
    """

    def _title_of(path: str) -> str:
        title = load_metadata_from_safetensors(path).get(MODELSPEC_TITLE)
        if title is None:
            stem, _ext = os.path.splitext(os.path.basename(path))  # fall back to the file name
            title = stem
        return title

    return ", ".join(map(_title_of, models))
# endregion
# NOTE(review): the block below is a disabled self-test kept inside a bare string
# literal, so it never runs. It would not work if re-enabled as-is:
#   - it calls `get`, which is not defined in this module;
#   - it uses `struct` and `json`, which are not imported here;
#   - it reads metadata["modelspec.hash_sha256"], which build_metadata does not
#     currently set, because its hash calculation is commented out.
r"""
if __name__ == "__main__":
import argparse
import torch
from safetensors.torch import load_file
from library import train_util
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
args = parser.parse_args()
print(f"Loading {args.ckpt}")
state_dict = load_file(args.ckpt)
print(f"Calculating metadata")
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
print(metadata)
del state_dict
# by reference implementation
with open(args.ckpt, mode="rb") as file_data:
file_hash = hashlib.sha256()
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
header = json.loads(file_data.read(head_len[0])) # header itself, json string
content = (
file_data.read()
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
file_hash.update(content)
# ===== Update the hash for modelspec =====
by_ref = f"0x{file_hash.hexdigest()}"
print(by_ref)
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
"""