# Spaces:
# Running
# Running
import argparse
import gc
import json
import re
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
# Default reference key list used by validate_keys(), write_safetensors_key(),
# stkey() and the -r/--rfile CLI option. It is a relative path, so it is
# resolved against the current working directory. (Presumably the key list of
# the SDXL base checkpoint — confirm against the keys/ directory.)
SDXL_KEYS_FILE: str = "keys/sdxl_keys.txt"
def list_uniq(l):
    """Return the unique items of *l* in order of first occurrence.

    ``dict`` preserves insertion order (guaranteed since Python 3.7), so
    ``dict.fromkeys`` de-duplicates in a single O(n) pass. The
    ``sorted(set(l), key=l.index)`` idiom does the same thing in O(n^2),
    because every ``index`` call is a linear scan. Items must be hashable.
    """
    return list(dict.fromkeys(l))
def read_safetensors_metadata(path: str):
    """Return the ``__metadata__`` dict stored in a safetensors file header.

    A safetensors file starts with an 8-byte little-endian unsigned integer N,
    followed by an N-byte UTF-8 JSON header. Only that header is read; tensor
    data is never touched.

    Returns:
        The header's ``__metadata__`` mapping, or ``{}`` if it has none.

    Raises:
        ValueError: if the file is shorter than the 8-byte length prefix, if
            the declared header length does not fit in the file, or if the
            header is not a JSON object. ``json.JSONDecodeError`` and
            ``UnicodeDecodeError`` are ``ValueError`` subclasses as well.
    """
    file_size = Path(path).stat().st_size
    with open(path, 'rb') as f:
        size_bytes = f.read(8)
        if len(size_bytes) != 8:
            raise ValueError(f"{path}: too short to be a safetensors file")
        header_size = int.from_bytes(size_bytes, 'little')
        # Check the declared length before reading. With a garbage prefix,
        # f.read() would otherwise pull the rest of a possibly multi-GB file
        # into memory, or silently parse a random tail.
        if header_size > file_size - 8:
            raise ValueError(
                f"{path}: header length {header_size} exceeds file size {file_size}"
            )
        header_bytes = f.read(header_size)
    header = json.loads(header_bytes.decode('utf-8'))
    if not isinstance(header, dict):
        raise ValueError(f"{path}: safetensors header is not a JSON object")
    return header.get('__metadata__', {})
def keys_from_file(path: str):
    """Read a newline-separated list of keys from a UTF-8 text file.

    Surrounding whitespace is stripped from every line, and blank lines are
    skipped. A trailing newline or an empty line in the file therefore never
    yields a spurious ``""`` key.

    Best-effort: if the file cannot be opened or decoded, the error is printed
    and an empty list is returned.
    """
    try:
        with open(Path(path), encoding='utf-8', mode='r') as f:
            return [key for line in f if (key := line.strip())]
    except (OSError, UnicodeDecodeError) as e:
        print(e)
        return []
def validate_keys(keys: list[str], rfile: str=SDXL_KEYS_FILE):
    """Compare *keys* against the reference key list stored in *rfile*.

    Args:
        keys: key names to check.
        rfile: path of the reference key list (one key per line).

    Returns:
        ``(missing, added)``, both de-duplicated:
        ``missing`` holds reference keys absent from *keys*, in reference-file
        order; ``added`` holds keys absent from the reference, in *keys* order.

    If the reference file cannot be read, ``keys_from_file`` prints the error
    and yields no reference keys, so every key is reported as added.
    """
    rkeys = keys_from_file(rfile)
    # Build each membership set once so the comparison stays linear.
    key_set = set(keys)
    ref_set = set(rkeys)
    missing = [k for k in dict.fromkeys(rkeys) if k not in key_set]
    added = [k for k in dict.fromkeys(keys) if k not in ref_set]
    return missing, added
def read_safetensors_key(path: str):
    """Return the tensor names stored in a safetensors file.

    ``safe_open`` parses only the file header, so no tensor data is loaded.
    That matters for multi-GB checkpoints. ``safetensors.torch.load_file``
    builds its dict by iterating this same ``keys()`` listing, so the order
    matches what ``load_file(path).keys()`` would give.

    Best-effort: on any read error the exception is printed and an empty list
    is returned; stkey() treats an empty list as "nothing to report".
    """
    try:
        with safe_open(str(Path(path)), framework="pt", device="cpu") as f:
            return list(f.keys())
    except Exception as e:
        print(e)
        return []
def write_safetensors_key(keys: list[str], path: str, is_validate: bool=True, rpath: str=SDXL_KEYS_FILE):
    """Write *keys* to *path*, one per line, optionally with validation reports.

    When *is_validate* is true, the keys are compared with the reference list
    in *rpath* via ``validate_keys``. Two report files are then written in the
    same directory as *path*: ``<stem>_missing.txt`` and ``<stem>_added.txt``.

    Returns:
        True on success. False if *keys* is empty or any write fails; in the
        failure case the error is printed.
    """
    if not keys:
        return False
    try:
        out_path = Path(path)
        out_path.write_text("\n".join(keys), encoding='utf-8')
        if is_validate:
            missing, added = validate_keys(keys, rpath)
            # with_name() keeps the reports next to *path*; a bare ``stem``
            # would drop the directory and write them into the CWD instead.
            out_path.with_name(out_path.stem + "_missing.txt").write_text(
                "\n".join(missing), encoding='utf-8'
            )
            out_path.with_name(out_path.stem + "_added.txt").write_text(
                "\n".join(added), encoding='utf-8'
            )
        return True
    except Exception as e:
        print(e)
        return False
def stkey(input: str, out_filename: str="", is_validate: bool=True, rfile: str=SDXL_KEYS_FILE):
    """Print (and optionally save) the tensor keys of a safetensors file.

    Args:
        input: path of the .safetensors file to inspect.
        out_filename: if non-empty, the key list is also written there via
            ``write_safetensors_key``.
        is_validate: also compare the keys with the reference list in *rfile*
            and print the missing/added keys.
        rfile: reference key list used for validation.

    Nothing is written or printed when no keys could be read.
    """
    keys = read_safetensors_key(input)
    if not keys:
        return

    if out_filename:
        write_safetensors_key(keys, out_filename, is_validate, rfile)

    print("Metadata:")
    print(read_safetensors_metadata(input))
    print("\nKeys:")
    print("\n".join(keys))

    if not is_validate:
        return
    missing, added = validate_keys(keys, rfile)
    print("\nMissing Keys:")
    print("\n".join(missing))
    print("\nAdded Keys:")
    print("\n".join(added))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Input safetensors file.")
    parser.add_argument("-s", "--save", action="store_true", default=False, help="Output to text file.")
    parser.add_argument("-o", "--output", default="", type=str, help="Output to specific text file.")
    parser.add_argument("-v", "--val", action="store_false", default=True, help="Disable key validation.")
    parser.add_argument("-r", "--rfile", default=SDXL_KEYS_FILE, type=str, help="Specify reference file to validate keys.")
    args = parser.parse_args()

    # Decide where (if anywhere) to save the key list: -o takes precedence
    # over -s. With neither flag, out_filename stays "" and stkey() only
    # prints, so the name must be bound up front for that plain invocation.
    out_filename = ""
    if args.save:
        out_filename = Path(args.input).stem + ".txt"
    if args.output:
        out_filename = args.output

    stkey(args.input, out_filename, args.val, args.rfile)

# Usage:
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors -s
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors -o key.txt
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors -v                  (skip key validation)
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors -r other_keys.txt   (custom reference list)