import argparse
import gc
import json
from pathlib import Path

import torch
from safetensors.torch import load_file


# Default reference key list for validation (resolved relative to the current working directory).
SDXL_KEYS_FILE = "keys/sdxl_keys.txt"


def list_uniq(l):
    # Remove duplicates while keeping the order of first occurrence.
    return list(dict.fromkeys(l))


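def read_safetensors_metadata(path: str):
    # A .safetensors file starts with an 8-byte little-endian header length,
    # followed by a UTF-8 JSON header; user metadata lives under "__metadata__".
    with open(path, 'rb') as f:
        header_size = int.from_bytes(f.read(8), 'little')
        header_json = f.read(header_size).decode('utf-8')
    header = json.loads(header_json)
    return header.get('__metadata__', {})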


def keys_from_file(path: str):
    # Read one key per line, ignoring surrounding whitespace and blank lines.
    keys = []
    try:
        with open(Path(path), encoding='utf-8', mode='r') as f:
            for line in f:
                line = line.strip()
                if line: keys.append(line)
    except Exception as e:
        print(e)
    return keys


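def validate_keys(keys: list[str], rfile: str=SDXL_KEYS_FILE):
    """Compare keys with the reference key list in rfile.

    Returns (missing, added): keys found only in the reference file and
    keys found only in the model, each in order of first occurrence.
    """
    missing = []
    added = []
    try:
        rkeys = keys_from_file(rfile)
        key_set, rkey_set = set(keys), set(rkeys)  # build once instead of on every iteration
        for key in list_uniq(keys + rkeys):
            if key in rkey_set and key not in key_set: missing.append(key)
            if key in key_set and key not in rkey_set: added.append(key)
    except Exception as e:
        print(e)
    return missing, added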


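def read_safetensors_key(path: str):
    # Defined before the try block so the cleanup below cannot hit an
    # unbound name when load_file() fails.
    keys = []
    state_dict = None
    try:
        # load_file() loads every tensor; only the key names are kept.
        state_dict = load_file(str(Path(path)))
        for k in list(state_dict.keys()):
            keys.append(k)
            state_dict.pop(k)  # release each tensor as soon as its key is recorded
    except Exception as e:
        print(e)
    finally:
        del state_dict
        torch.cuda.empty_cache()
        gc.collect()
    return keys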


def write_safetensors_key(keys: list[str], path: str, is_validate: bool=True, rpath: str=SDXL_KEYS_FILE):
    if len(keys) == 0: return False
    out = Path(path)
    try:
        with open(out, encoding='utf-8', mode='w') as f:
            f.write("\n".join(keys))
        if is_validate:
            missing, added = validate_keys(keys, rpath)
            # Write the reports next to the output file instead of into the working directory.
            with open(out.with_name(out.stem + "_missing.txt"), encoding='utf-8', mode='w') as f:
                f.write("\n".join(missing))
            with open(out.with_name(out.stem + "_added.txt"), encoding='utf-8', mode='w') as f:
                f.write("\n".join(added))
        return True
    except Exception as e:
        print(e)
        return False


def stkey(input: str, out_filename: str="", is_validate: bool=True, rfile: str=SDXL_KEYS_FILE):
    """Print the metadata and tensor keys of a safetensors file, optionally saving and validating the keys."""
    keys = read_safetensors_key(input)
    if len(keys) == 0: return
    if out_filename: write_safetensors_key(keys, out_filename, is_validate, rfile)
    print("Metadata:")
    print(read_safetensors_metadata(input))
    print("\nKeys:")
    print("\n".join(keys))
    if is_validate:
        missing, added = validate_keys(keys, rfile)
        print("\nMissing Keys:")
        print("\n".join(missing))
        print("\nAdded Keys:")
        print("\n".join(added))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="List the tensor keys of a safetensors file and compare them with a reference key list.")
    parser.add_argument("input", type=str, help="Input safetensors file.")
    parser.add_argument("-s", "--save", action="store_true", default=False, help="Save keys to <input stem>.txt.")
    parser.add_argument("-o", "--output", default="", type=str, help="Save keys to the given text file (overrides -s).")
    parser.add_argument("-v", "--val", action="store_false", default=True, help="Disable key validation.")
    parser.add_argument("-r", "--rfile", default=SDXL_KEYS_FILE, type=str, help="Reference key file used for validation.")

    args = parser.parse_args()

    # Default is print-only; -s saves <input stem>.txt in the working directory, and -o takes precedence.
    out_filename = ""
    if args.save: out_filename = Path(args.input).stem + ".txt"
    if args.output: out_filename = args.output

    stkey(args.input, out_filename, args.val, args.rfile)


# Usage:
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors -s
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors -o key.txt
# python stkey.py sd_xl_base_1.0_0.9vae.safetensors -s -v   (save keys without validation)