import os
import re
import hashlib
from io import BytesIO
from typing import Callable, Union, Optional

import torch
import safetensors.torch
from safetensors import safe_open
from safetensors.torch import save_file

re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    },
}
def convert_diffusers_name_to_compvis(key, is_sd2):
    """Map a diffusers/kohya-style LoRA key to its CompVis (original Stable
    Diffusion checkpoint) equivalent. Keys that match no known pattern are
    returned unchanged."""

    def match(match_list, regex_text):
        # Compile each pattern once and cache it in re_compiled.
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        # Fill match_list with the captured groups, converting digit groups to ints.
        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
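# Illustrative example (added for clarity, not part of the original script): a
# typical kohya-style key for the first down-block attention maps as follows,
# and lbw_lora below would classify the result as block "IN01".
#   convert_diffusers_name_to_compvis(
#       "lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight", is_sd2=False)
#   -> "diffusion_model_input_blocks_1_1_proj_in.lora_up.weight"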
def safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later."""

    # Only the training metadata (ss_*) takes part in the hash, so user-added
    # metadata does not change the result.
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    serialized = safetensors.torch.save(tensors, metadata)
    b = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
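# Illustrative call (the tensor name and metadata key are made up for the example):
#   sd = {"lora_unet_conv_in.lora_up.weight": torch.zeros(4, 4)}
#   model_hash, legacy_hash = safetensors_hashes(sd, {"ss_output_name": "example"})
#   print(model_hash, legacy_hash)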
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    m = hashlib.sha256()

    # Hash 0x10000 bytes starting at offset 0x100000, then keep the first 8 hex digits.
    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]


def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    # The first 8 bytes of a .safetensors file are the little-endian length of the
    # JSON header; skip the header and hash only the tensor data that follows it.
    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()
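# Illustrative use on a file (hypothetical path): both helpers only need seek/read,
# so an on-disk LoRA can be hashed without deserializing it:
#   with open("some_lora.safetensors", "rb") as fh:
#       print(addnet_hash_safetensors(fh), addnet_hash_legacy(fh))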
def lbw_lora(input_, output, ratios):
    """Apply LoRA Block Weight (LBW): scale each block of a LoRA by its ratio."""
    print("Apply LBW")

    assert isinstance(input_, str)
    assert isinstance(output, str)
    assert isinstance(ratios, str)
    assert os.path.exists(input_), f"{input_} does not exist"
    assert not os.path.exists(output), f"{output} already exists"

    LOAD_PATH = input_
    SAVE_PATH = output
    RATIOS = [float(x) for x in ratios.split(",")]
    LAYERS = len(RATIOS)
    assert LAYERS in [17, 26], f"expected 17 or 26 comma-separated ratios, got {LAYERS}"

    # Block labels for the 17-block and 26-block weighting schemes (LoRA Block Weight convention).
    BLOCKID17 = [
        "BASE", "IN01", "IN02", "IN04", "IN05", "IN07", "IN08", "M00",
        "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11"]
    BLOCKID26 = [
        "BASE", "IN00", "IN01", "IN02", "IN03", "IN04", "IN05", "IN06", "IN07", "IN08", "IN09", "IN10", "IN11", "M00",
        "OUT00", "OUT01", "OUT02", "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11"]

    if LAYERS == 17:
        RATIO_OF_ = dict(zip(BLOCKID17, RATIOS))
    if LAYERS == 26:
        RATIO_OF_ = dict(zip(BLOCKID26, RATIOS))
    print(RATIO_OF_)

    # Map a CompVis key prefix to one of the block labels above.
    PATTERNS = [
        r"^transformer_text_model_(encoder)_layers_(\d+)_.*",
        r"^diffusion_model_(in)put_blocks_(\d+)_.*",
        r"^diffusion_model_(middle)_block_(\d+)_.*",
        r"^diffusion_model_(out)put_blocks_(\d+)_.*"]

    def replacement(match):
        g1 = str(match.group(1))
        g2 = int(match.group(2))
        assert g1 in ["encoder", "in", "middle", "out"]
        assert isinstance(g2, int)

        if g1 == "encoder":
            return "BASE"
        if g1 == "middle":
            return "M00"
        return f"{g1.upper()}{g2:02}"

    def compvis_name_to_blockid(compvis_name):
        strings = compvis_name
        for pattern in PATTERNS:
            strings = re.sub(pattern, replacement, strings)
            if strings != compvis_name:
                break
        assert strings != compvis_name, f"no block pattern matched {compvis_name}"
        blockid = strings

        if LAYERS == 17:
            assert blockid in BLOCKID26, f"Incorrect layer {blockid}"
            assert blockid in BLOCKID17, f"{blockid} is not included in the 17 layers. Maybe it is a 26-layer LoRA?"
        if LAYERS == 26:
            assert blockid in BLOCKID26, f"Incorrect layer {blockid}"
        return blockid

    with safe_open(LOAD_PATH, framework="pt", device="cpu") as f:
        tensors = {}
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
            compvis_name = convert_diffusers_name_to_compvis(key, is_sd2=False)
            blockid = compvis_name_to_blockid(compvis_name)
            # Scaling lora_up.weight rescales the whole block's LoRA contribution.
            if compvis_name.endswith("lora_up.weight"):
                tensors[key] *= RATIO_OF_[blockid]
                print(f"({blockid}) {compvis_name} "
                      f"updated with factor {RATIO_OF_[blockid]}")

    save_file(tensors, SAVE_PATH)

    print("Done")
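# Illustrative usage (paths and ratios are placeholders, not from the original script).
# The ratio string must contain 17 or 26 comma-separated values, matched positionally
# to BLOCKID17 or BLOCKID26 respectively:
#   lbw_lora("input_lora.safetensors", "output_lora_lbw.safetensors",
#            "1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0")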
|