import re |
from sklearn.linear_model import PassiveAggressiveClassifier |
import torch |
import math |
import os |
import gc |
import gradio as gr |
from torchmetrics import Precision |
import modules.shared as shared |
import gc |
from safetensors.torch import load_file, save_file |
from typing import List |
from tqdm import tqdm |
from modules import sd_models,scripts |
from scripts.mergers.model_util import load_models_from_stable_diffusion_checkpoint,filenamecutter,savemodel |
from modules.ui import create_refresh_button |
def on_ui_tabs():
    """Build the LoRA merge / extract UI and wire its event handlers.

    Creates the gradio widgets for merging LoRAs into a checkpoint, merging
    LoRAs together and extracting a LoRA from two checkpoints, then connects
    the buttons to ``pluslora``, ``lmerge`` and ``makelora``.

    NOTE(review): the source this was restored from had lost its indentation,
    so the nesting of widgets inside the ``gr.Row`` groups below is a
    reconstruction -- confirm the layout against the running UI.
    """
    import lora

    sml_path_root = scripts.basedir()
    # Built-in block-weight presets, used when scripts/lbwpresets.txt cannot be read.
    # One "NAME:w0,...,w16" entry per line, no trailing newline.
    LWEIGHTSPRESETS = "\n".join([
        "NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0",
        "ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1",
        "INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0",
        "IND:1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0",
        "INALL:1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0",
        "MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0",
        "OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0",
        "OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1",
        "OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1",
        "ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5",
    ])
    sml_filepath = os.path.join(sml_path_root, "scripts", "lbwpresets.txt")
    try:
        with open(sml_filepath, encoding="utf-8") as f:
            sml_lbwpresets = f.read()
    except OSError:
        # No user preset file (or it is unreadable): fall back to the defaults.
        sml_lbwpresets = LWEIGHTSPRESETS

    with gr.Blocks(analytics_enabled=False):
        sml_submit_result = gr.Textbox(label="Message")
        with gr.Row().style(equal_height=False):
            sml_cpmerge = gr.Button(elem_id="model_merger_merge", value="Merge to Checkpoint",variant='primary')
            sml_makelora = gr.Button(elem_id="model_merger_merge", value="Make LoRA (alpha * A - beta * B)",variant='primary')
            sml_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint A",interactive=True)
            create_refresh_button(sml_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
            sml_model_b = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint B",interactive=True)
            create_refresh_button(sml_model_b, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
        with gr.Row().style(equal_height=False):
            sml_merge = gr.Button(elem_id="model_merger_merge", value="Merge LoRAs",variant='primary')
            alpha = gr.Slider(label="alpha", minimum=-1.0, maximum=2, step=0.001, value=1)
            beta = gr.Slider(label="beta", minimum=-1.0, maximum=2, step=0.001, value=1)
        with gr.Row().style(equal_height=False):
            sml_settings = gr.CheckboxGroup(["same to Strength", "overwrite"], label="settings")
            precision = gr.Radio(label = "save precision",choices=["float","fp16","bf16"],value = "fp16",type="value")
        with gr.Row().style(equal_height=False):
            sml_dim = gr.Radio(label = "remake dimension",choices = ["no","auto",*[2**(x+2) for x in range(9)]],value = "no",type = "value")
            sml_filename = gr.Textbox(label="filename(option)",lines=1,visible =True,interactive = True)
        sml_loranames = gr.Textbox(label='LoRAname1:ratio1:Blocks1,LoRAname2:ratio2:Blocks2,...(":blocks" is option, not necessary)',lines=1,value="",visible =True)
        sml_dims = gr.CheckboxGroup(label = "limit dimension",choices=[],value = [],type="value",interactive=True,visible = False)
        with gr.Row().style(equal_height=False):
            sml_calcdim = gr.Button(elem_id="calcloras", value="calculate dimension of LoRAs(It may take a few minutes if there are many LoRAs)",variant='primary')
            sml_update = gr.Button(elem_id="calcloras", value="update list",variant='primary')
        sml_loras = gr.CheckboxGroup(label = "Lora",choices=[x[0] for x in lora.available_loras.items()],type="value",interactive=True,visible = True)
        sml_loraratios = gr.TextArea(label="",value=sml_lbwpresets,visible =True,interactive = True)

        sml_merge.click(
            fn=lmerge,
            inputs=[sml_loranames,sml_loraratios,sml_settings,sml_filename,sml_dim,precision],
            outputs=[sml_submit_result]
        )
        sml_makelora.click(
            fn=makelora,
            inputs=[sml_model_a,sml_model_b,sml_dim,sml_filename,sml_settings,alpha,beta,precision],
            outputs=[sml_submit_result]
        )
        sml_cpmerge.click(
            fn=pluslora,
            inputs=[sml_loranames,sml_loraratios,sml_settings,sml_filename,sml_model_a,precision],
            outputs=[sml_submit_result]
        )

        # State shared by the handlers below:
        #   llist: LoRA name -> dimension label ("" until it has been calculated)
        #   dlist: integer dimensions seen so far (kept sorted)
        #   dn:    non-integer dimension labels seen so far
        llist = {}
        dlist = []
        dn = []

        def lora_choices():
            """Return the checkbox labels, formatted as ``name(dimension)``."""
            return [f"{name}({label})" for name, label in llist.items()]

        def updateloras():
            """Rescan the LoRA directory and add newly found LoRAs to the list."""
            lora.list_available_loras()
            for name in lora.available_loras:
                if name not in llist:
                    llist[name] = ""
            return gr.update(choices=lora_choices())

        sml_update.click(fn = updateloras,outputs = [sml_loras])

        def calculatedim():
            """Read the dimension of every LoRA that has not been measured yet."""
            print("listing dimensions...")
            for n in tqdm(lora.available_loras.items()):
                if n[0] in llist:
                    if llist[n[0]] != "":
                        continue
                c_lora = lora.available_loras.get(n[0], None)
                d, t = dimgetter(c_lora.filename)
                if t == "LoCon":
                    d = f"{d}:{t}"
                if d not in dlist:
                    if isinstance(d, int):
                        dlist.append(d)
                    elif d not in dn:
                        dn.append(d)
                llist[n[0]] = d
            dlist.sort()
            return gr.update(choices=lora_choices(), value=[]), gr.update(visible=True, choices=list(dlist + dn))

        sml_calcdim.click(
            fn=calculatedim,
            inputs=[],
            outputs=[sml_loras,sml_dims]
        )

        def dimselector(dims):
            """Restrict the LoRA checkboxes to the selected dimensions."""
            if dims == []:
                return gr.update(choices=lora_choices())
            rl = []
            for d in dims:
                for i in llist.items():
                    if d == i[1]:
                        rl.append(f"{i[0]}({i[1]})")
            return gr.update(choices=list(rl), value=[])

        def llister(names):
            """Turn the ticked checkboxes into a ``name:1.0,name:1.0`` string."""
            if names == []:
                return ""
            # Strip the "(dimension)" suffix without mutating the list gradio passed in.
            cleaned = [n[:n.rfind("(")] if "(" in n else n for n in names]
            return ":1.0,".join(cleaned) + ":1.0"

        sml_loras.change(fn=llister,inputs=[sml_loras],outputs=[sml_loranames])
        sml_dims.change(fn=dimselector,inputs=[sml_dims],outputs=[sml_loras])
def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,precision):
    """Extract a LoRA from ``alpha * model_a - beta * model_b`` and save it.

    Args:
        model_a, model_b: checkpoint names as chosen in the UI dropdowns.
        dim: rank of the extracted LoRA; anything that is not an int
            (for example the "no" / "auto" radio values) falls back to 128.
        saveto: output file name; generated from the model names when empty.
        settings: list of option strings; "overwrite" permits replacing a file.
        alpha, beta: coefficients of the weight difference.
        precision: save precision label handed to ``svd``.

    Returns:
        A status message for the UI.
    """
    print("make LoRA start")
    # An unset dropdown may deliver None or [] rather than "", so test truthiness.
    if not model_a or not model_b:
        return "ERROR: No model Selected"
    gc.collect()

    if not saveto:
        saveto = makeloraname(model_a, model_b)
    if ".safetensors" not in saveto:
        saveto += ".safetensors"
    saveto = os.path.join(shared.cmd_opts.lora_dir, saveto)

    # bool is a subclass of int but is never a valid rank.
    if not isinstance(dim, int) or isinstance(dim, bool):
        dim = 128

    if os.path.isfile(saveto) and "overwrite" not in settings:
        _err_msg = f"Output file ({saveto}) existed and was not saved"
        print(_err_msg)
        return _err_msg

    svd(fullpathfromname(model_a), fullpathfromname(model_b), False, dim, precision, saveto, alpha, beta)
    return f"saved to {saveto}"
def lmerge(loranames,loraratioss,settings,filename,dim,precision):
    """Merge one or more LoRAs into a single LoRA file.

    Args:
        loranames: "name:ratio[:preset]" entries separated by commas.
        loraratioss: preset text, one "NAME:w0,w1,..." line per preset
            (lines must hold 17 or 26 comma separated weights).
        settings: list of option strings ("same to Strength", "overwrite").
        filename: output file name; derived from ``loranames`` when empty.
        dim: "no", "auto" or an integer rank to rebuild the result at.
        precision: save precision label understood by ``str_to_dtype``.

    Returns:
        A status message for the UI.
    """
    if not loranames or not loranames.strip():
        return "ERROR: No LoRA Selected"

    import lora

    lnames = [entry.split(":") for entry in loranames.split(",")]

    # Rescan the LoRA directory when any requested name is not known yet.
    if any(lora.available_loras.get(n[0], None) is None for n in lnames):
        lora.list_available_loras()

    # Preset name -> comma separated weight string.
    ldict = {}
    for line in loraratioss.splitlines():
        if ":" not in line or line.count(",") not in (16, 25):
            continue
        parts = line.split(":")
        # Keys are stripped so that the stripped lookup below always finds them.
        ldict[parts[0].strip()] = parts[1]

    ln = []  # file names
    lr = []  # per-block ratios
    ld = []  # dimensions
    lt = []  # network types
    dmax = 1
    for n in lnames:
        if len(n) < 2:
            return f"ERROR: ratio is missing for LoRA '{n[0]}'"
        strength = float(n[1])
        preset = n[2].strip() if len(n) == 3 else None
        if preset in ldict:
            ratio = [float(r) * strength for r in ldict[preset].split(",")]
        else:
            ratio = [strength] * 17

        c_lora = lora.available_loras.get(n[0], None)
        if c_lora is None:
            return f"ERROR: LoRA not found: {n[0]}"
        ln.append(c_lora.filename)
        lr.append(ratio)
        d, t = dimgetter(c_lora.filename)
        lt.append(t)
        ld.append(d)
        # dimgetter may return "LyCORIS" or "unknown"; only integers can be compared.
        if isinstance(d, int) and d > dmax:
            dmax = d

    if filename == "":
        filename = loranames.replace(",", "+").replace(":", "_")
    if ".safetensors" not in filename:
        filename += ".safetensors"
    filename = os.path.join(shared.cmd_opts.lora_dir, filename)

    # "auto" is a value of the dimension radio; it used to be looked up in
    # `settings`, where it can never appear, so the branch was unreachable.
    auto = dim == "auto" or "auto" in settings
    dim = int(dim) if dim != "no" and dim != "auto" else 0

    if "LyCORIS" in ld or "LoCon" in lt:
        if len(ld) != 1:
            return "multiple merge of LyCORIS is not supported"
        sd = lycomerge(ln[0], lr[0])
    elif dim > 0:
        print("change demension to ", dim)
        sd = merge_lora_models_dim(ln, lr, dim, settings)
    elif auto and ld.count(ld[0]) != len(ld):
        print("change demension to ",dmax)
        sd = merge_lora_models_dim(ln, lr, dmax, settings)
    else:
        sd = merge_lora_models(ln, lr, settings)

    if os.path.isfile(filename) and "overwrite" not in settings:
        _err_msg = f"Output file ({filename}) existed and was not saved"
        print(_err_msg)
        return _err_msg

    save_to_file(filename, sd, sd, str_to_dtype(precision))
    return "saved : "+filename
def pluslora(lnames,loraratios,settings,output,model,precision):
    """Bake one or more LoRAs into a checkpoint and save the result.

    Args:
        lnames: "name:ratio[:preset]" entries separated by commas.
        loraratios: preset text, one "NAME:w0,w1,..." line per preset
            (lines must hold 17 or 26 comma separated weights).
        settings: list of option strings forwarded to ``savemodel``.
        output: output name forwarded to ``savemodel``.
        model: checkpoint name chosen in the UI.
        precision: precision label, appended to the settings for ``savemodel``.

    Returns:
        The message returned by ``savemodel``, or an error message.
    """
    # An unset dropdown may deliver None, "" or [].
    if not model:
        return "ERROR: No model Selected"
    if lnames == "":
        return "ERROR: No LoRA Selected"

    print("plus LoRA start")
    import lora

    lnames = [entry.split(":") for entry in lnames.split(",")]

    # Preset name -> comma separated weight string.
    ldict = {}
    for line in loraratios.splitlines():
        if ":" not in line or line.count(",") not in (16, 25):
            continue
        parts = line.split(":")
        ldict[parts[0].strip()] = parts[1]

    names = []
    filenames = []
    lweis = []
    for n in lnames:
        if len(n) < 2:
            return f"ERROR: ratio is missing for LoRA '{n[0]}'"
        strength = float(n[1])
        # Look the preset up by its stripped name; the keys above are stripped too.
        preset = n[2].strip() if len(n) == 3 else None
        if preset in ldict:
            ratio = [float(r) * strength for r in ldict[preset].split(",")]
        else:
            ratio = [strength] * 17

        c_lora = lora.available_loras.get(n[0], None)
        if c_lora is None:
            return f"ERROR: LoRA not found: {n[0]}"
        names.append(n[0])
        filenames.append(c_lora.filename)
        _, t = dimgetter(c_lora.filename)
        if "LyCORIS" in t:
            return "LyCORIS merge is not supported"
        lweis.append(ratio)

    modeln = filenamecutter(model, True)
    dname = modeln
    for n in names:
        dname = dname + "+" + n

    checkpoint_info = sd_models.get_closet_checkpoint_match(model)
    print(f"Loading {model}")
    theta_0 = sd_models.read_state_dict(checkpoint_info.filename, "cpu")

    # Map "dots replaced by underscores, '_weight' removed, text after 'model_'"
    # back to the real checkpoint key, so converted LoRA names can be looked up.
    keychanger = {}
    for key in theta_0.keys():
        if "model" in key:
            skey = key.replace(".", "_").replace("_weight", "")
            keychanger[skey.split("model_", 1)[1]] = key

    for name, filename, lwei in zip(names, filenames, lweis):
        print(f"loading: {name}")
        lora_sd = load_state_dict(filename, torch.float)

        print(f"merging..." ,lwei)
        for key in lora_sd.keys():
            # Each module is handled once, through its lora_down key.
            if "lora_down" not in key:
                continue

            fullkey = convert_diffusers_name_to_compvis(key)
            ratio = 1
            for i, block in enumerate(LORABLOCKS):
                if block in fullkey:
                    ratio = lwei[i]
            msd_key = fullkey.split(".", 1)[0]

            up_key = key.replace("lora_down", "lora_up")
            alpha_key = key[:key.index("lora_down")] + 'alpha'

            down_weight = lora_sd[key].to(device="cpu")
            up_weight = lora_sd[up_key].to(device="cpu")

            dim = down_weight.size()[0]
            alpha = lora_sd.get(alpha_key, dim)
            scale = alpha / dim

            weight = theta_0[keychanger[msd_key]].to(device="cpu")
            if not len(down_weight.size()) == 4:
                # linear
                weight = weight + ratio * (up_weight @ down_weight) * scale
            else:
                # conv2d: drop the two trailing kernel axes, multiply, restore them
                weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                                           ).unsqueeze(2).unsqueeze(3) * scale
            theta_0[keychanger[msd_key]] = torch.nn.Parameter(weight)

    # Build a new list instead of appending to the caller's `settings`.
    save_settings = list(settings) + [precision]
    result = savemodel(theta_0, dname, output, save_settings, model)
    del theta_0
    gc.collect()
    return result
def save_to_file(file_name, model, state_dict, dtype):
    """Cast the tensors of ``state_dict`` to ``dtype`` and write ``model`` to disk.

    The cast is done in place and skipped when ``dtype`` is None. Files ending
    in ``.safetensors`` are written with ``save_file``; anything else goes
    through ``torch.save``.
    """
    if dtype is not None:
        for name, value in list(state_dict.items()):
            if type(value) is torch.Tensor:
                state_dict[name] = value.to(dtype)

    _, extension = os.path.splitext(file_name)
    writer = save_file if extension == '.safetensors' else torch.save
    writer(model, file_name)
# Patterns for the diffusers-style LoRA module names handled by
# convert_diffusers_name_to_compvis().
re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")

re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")

re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")

re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")

# (diffusers resnet sub-layer prefix, CompVis replacement), tried in this order.
_RESNET_LAYER_NAMES = (
    ("conv1", "in_layers_2"),
    ("conv2", "out_layers_3"),
    ("time_emb_proj", "emb_layers_1"),
    ("conv_shortcut", "skip_connection"),
)


def _convert_resnet_suffix(block, suffix):
    """Return ``block`` plus the CompVis name for a resnet ``suffix``, or None if unknown."""
    suffix = str(suffix)
    for diffusers_name, compvis_name in _RESNET_LAYER_NAMES:
        if suffix.startswith(diffusers_name):
            return f"{block}{compvis_name}{suffix[len(diffusers_name):]}"
    return None


def convert_diffusers_name_to_compvis(key):
    """Translate a diffusers-style LoRA key into the CompVis-style module name.

    Keys that match none of the known patterns (or whose resnet sub-layer is
    not recognised) are returned unchanged.
    """
    def groups(regex):
        # Captured groups with purely numeric ones converted to int, or None.
        found = re.match(regex, key)
        if not found:
            return None
        # fullmatch: a group such as "1x.weight" merely *starts* with digits
        # and must stay a string instead of being fed to int().
        return [int(x) if re_digits.fullmatch(x) else x for x in found.groups()]

    m = groups(re_unet_down_blocks)
    if m:
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"

    m = groups(re_unet_mid_blocks)
    if m:
        return f"diffusion_model_middle_block_1_{m[1]}"

    m = groups(re_unet_up_blocks)
    if m:
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    m = groups(re_unet_down_blocks_res)
    if m:
        converted = _convert_resnet_suffix(f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_", m[2])
        if converted is not None:
            return converted

    m = groups(re_unet_mid_blocks_res)
    if m:
        converted = _convert_resnet_suffix(f"diffusion_model_middle_block_{m[0]*2}_", m[1])
        if converted is not None:
            return converted

    m = groups(re_unet_up_blocks_res)
    if m:
        converted = _convert_resnet_suffix(f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_", m[2])
        if converted is not None:
            return converted

    m = groups(re_unet_downsample)
    if m:
        return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"

    m = groups(re_unet_upsample)
    if m:
        # The sub-index is 1 for up block 0 and 2 for every other up block.
        return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"

    m = groups(re_text_block)
    if m:
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key
# Largest absolute text-encoder weight difference that svd() still treats as "same".
MIN_DIFF = 1e-6


def str_to_dtype(p):
    """Map a precision label ('float', 'fp16', 'bf16') to a torch dtype, else None."""
    known = (
        ('float', torch.float),
        ('fp16', torch.float16),
        ('bf16', torch.bfloat16),
    )
    for label, dtype in known:
        if p == label:
            return dtype
    return None
def svd(model_a,model_b,v2,dim,save_precision,save_to,alpha,beta):
    """Extract a LoRA approximating ``alpha * model_a - beta * model_b`` via SVD.

    Args:
        model_a: checkpoint handed to load_models_from_stable_diffusion_checkpoint
            (the "t" side of the difference).
        model_b: checkpoint for the "o" side; its LoRA network is the one that
            gets filled in and saved.
        v2: flag forwarded to load_models_from_stable_diffusion_checkpoint.
        dim: rank kept from each decomposition; also used as network dim and alpha.
        save_precision: precision label converted with str_to_dtype.
        save_to: output path; its directory is created when missing.
        alpha, beta: coefficients applied to the model_a / model_b weights.

    Returns:
        ``save_to``.
    """
    save_dtype = str_to_dtype(save_precision)

    if model_a == model_b:
        # Same checkpoint on both sides: load once and share the modules, so
        # every difference below is (alpha - beta) * weight.
        text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)
        text_encoder_o, _, unet_o = text_encoder_t, _, unet_t
    else:
        print(f"loading SD model : {model_b}")
        text_encoder_o, _, unet_o = load_models_from_stable_diffusion_checkpoint(v2, model_b)

        print(f"loading SD model : {model_a}")
        text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)

    # One LoRA network per model; they are only used to enumerate matching modules
    # (and, for the "o" network, to hold and save the extracted weights).
    lora_network_o = create_network(1.0, dim, dim, None, text_encoder_o, unet_o)
    lora_network_t = create_network(1.0, dim, dim, None, text_encoder_t, unet_t)
    assert len(lora_network_o.text_encoder_loras) == len(
        lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "

    # Weight differences of the text encoder modules, keyed by LoRA module name.
    diffs = {}
    text_encoder_different = False
    for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
        lora_name = lora_o.lora_name
        module_o = lora_o.org_module
        module_t = lora_t.org_module
        diff = alpha*module_t.weight - beta*module_o.weight

        # A single module differing by more than MIN_DIFF marks the text encoder as changed.
        if torch.max(torch.abs(diff)) > MIN_DIFF:
            text_encoder_different = True

        diff = diff.float()
        diffs[lora_name] = diff

    if not text_encoder_different:
        print("Text encoder is same. Extract U-Net only.")
        # Drop the text encoder modules and discard the differences collected so far.
        lora_network_o.text_encoder_loras = []
        diffs = {}

    # Weight differences of the U-Net modules.
    for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
        lora_name = lora_o.lora_name
        module_o = lora_o.org_module
        module_t = lora_t.org_module
        diff = alpha*module_t.weight - beta*module_o.weight
        diff = diff.float()

        diffs[lora_name] = diff

    print("calculating by svd")
    rank = dim
    lora_weights = {}
    with torch.no_grad():
        for lora_name, mat in tqdm(list(diffs.items())):
            conv2d = (len(mat.size()) == 4)
            if conv2d:
                # NOTE(review): squeeze() drops every size-1 axis -- presumably only
                # 1x1 kernels reach this point; confirm.
                mat = mat.squeeze()

            U, S, Vh = torch.linalg.svd(mat)

            # Keep the first `rank` components and fold the singular values into U.
            U = U[:, :rank]
            S = S[:rank]
            U = U @ torch.diag(S)

            Vh = Vh[:rank, :]

            # Clamp both factors to +/- the CLAMP_QUANTILE quantile of their values.
            # NOTE(review): CLAMP_QUANTILE is not defined in the visible part of this
            # file -- confirm it exists at module level.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
            low_val = -hi_val

            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            lora_weights[lora_name] = (U, Vh)

    # Copy the factors into the "o" network's state dict.
    # NOTE(review): apply_to / save_weights of LoRANetwork are outside the visible source.
    lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True)
    lora_sd = lora_network_o.state_dict()
    print(f"LoRA has {len(lora_sd)} weights.")

    for key in list(lora_sd.keys()):
        if "alpha" in key:
            continue

        lora_name = key.split('.')[0]
        i = 0 if "lora_up" in key else 1  # 0 -> U (up weight), 1 -> Vh (down weight)

        weights = lora_weights[lora_name][i]
        if len(lora_sd[key].size()) == 4:
            # Conv2d weights: append two trailing axes to match the 4-D shape.
            weights = weights.unsqueeze(2).unsqueeze(3)

        assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
        lora_sd[key] = weights

    info = lora_network_o.load_state_dict(lora_sd)
    print(f"Loading extracted LoRA weights: {info}")

    dir_name = os.path.dirname(save_to)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)

    # Both the dim and alpha metadata entries are written as `dim`.
    metadata = {"ss_network_dim": str(dim), "ss_network_alpha": str(dim)}

    lora_network_o.save_weights(save_to, save_dtype, metadata)
    print(f"LoRA weights are saved to: {save_to}")
    return save_to
def load_state_dict(file_name, dtype):
    """Load a state dict from ``file_name`` and cast its tensors to ``dtype``.

    ``.safetensors`` files are read with ``load_file``; every other extension
    goes through ``torch.load`` onto the CPU. Non-tensor values are left as is.
    """
    _, extension = os.path.splitext(file_name)
    if extension == '.safetensors':
        sd = load_file(file_name)
    else:
        # NOTE: torch.load unpickles the file -- only open checkpoints you trust.
        sd = torch.load(file_name, map_location='cpu')

    for name, value in list(sd.items()):
        if type(value) is torch.Tensor:
            sd[name] = value.to(dtype)
    return sd
def dimgetter(filename):
    """Return ``(dimension, network type)`` for the LoRA file ``filename``.

    The type is "LoCon" when the file has a LoRA weight for the first resnet
    conv layer, "LyCORIS" (for both values) when a "hada_" key is met, and
    "LoRA" otherwise. ``("unknown", "unknown")`` is returned when no dimension
    could be determined.
    """
    lora_sd = load_state_dict(filename, torch.float)

    locon_marker = "lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight"
    kind = "LoCon" if locon_marker in lora_sd else None
    found_alpha = None
    dim = None

    for key, value in lora_sd.items():
        if found_alpha is None and 'alpha' in key:
            found_alpha = value
        # Only 2-D (linear) down weights supply the dimension.
        if dim is None and 'lora_down' in key and len(value.size()) == 2:
            dim = value.size()[0]
        if "hada_" in key:
            dim, kind = "LyCORIS", "LyCORIS"
        # Stop scanning once both an alpha and a dimension have been seen.
        if found_alpha is not None and dim is not None:
            break

    if kind is None:
        kind = "LoRA"
    return (dim, kind) if dim else ("unknown", "unknown")
def blockfromkey(key):
    """Return the index of the first LORABLOCKS entry found in the converted key (0 if none)."""
    fullkey = convert_diffusers_name_to_compvis(key)
    matches = (index for index, marker in enumerate(LORABLOCKS) if marker in fullkey)
    return next(matches, 0)
def merge_lora_models_dim(models, ratios, new_rank,sets):
    """Merge LoRAs of possibly different ranks and re-extract them at ``new_rank``.

    Every module's ``up @ down`` product is accumulated as a full weight
    (scaled by its per-block ratio and alpha / dim), then decomposed again
    with an SVD truncated to ``new_rank``.

    Args:
        models: LoRA file paths.
        ratios: one list of per-block ratios per model.
        new_rank: rank of the re-extracted LoRA; also stored as its alpha.
        sets: option strings; "same to Strength" switches to square-root ratios.

    Returns:
        State dict holding lora_up / lora_down / alpha entries per module.
    """
    accumulated = {}
    fugou = 1  # sign applied to the contribution; only updated by "same to Strength"
    merge_dtype = torch.float

    for model_path, block_ratios in zip(models, ratios):
        lora_sd = load_state_dict(model_path, merge_dtype)

        print(f"merging {model_path}: {block_ratios}")
        for key in tqdm(list(lora_sd.keys())):
            # Each module is processed once, through its lora_down key.
            if 'lora_down' not in key:
                continue

            module_name = key[:key.rfind(".lora_down")]
            down_weight = lora_sd[key]
            up_weight = lora_sd[module_name + '.lora_up.weight']

            network_dim = down_weight.size()[0]
            alpha = lora_sd.get(module_name + '.alpha', network_dim)

            in_dim = down_weight.size()[1]
            out_dim = up_weight.size()[0]
            is_conv = len(down_weight.size()) == 4

            if module_name in accumulated:
                weight = accumulated[module_name]
            else:
                shape = (out_dim, in_dim, 1, 1) if is_conv else (out_dim, in_dim)
                weight = torch.zeros(shape, dtype=merge_dtype)

            ratio = block_ratios[blockfromkey(key)]
            if "same to Strength" in sets:
                if ratio > 0:
                    ratio, fugou = ratio**0.5, 1
                else:
                    ratio, fugou = abs(ratio)**0.5, -1

            scale = (alpha / network_dim)
            if is_conv:
                product = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                           ).unsqueeze(2).unsqueeze(3)
            else:
                product = up_weight @ down_weight
            weight = weight + ratio * product * scale * fugou

            accumulated[module_name] = weight

    print("extract new lora...")
    merged_lora_sd = {}
    with torch.no_grad():
        for module_name, mat in tqdm(list(accumulated.items())):
            is_conv = (len(mat.size()) == 4)
            if is_conv:
                mat = mat.squeeze()

            U, S, Vh = torch.linalg.svd(mat)

            # Keep the first new_rank components, folding the singular values into U.
            U = U[:, :new_rank] @ torch.diag(S[:new_rank])
            Vh = Vh[:new_rank, :]

            # Clamp both factors to +/- the CLAMP_QUANTILE quantile of their values.
            hi_val = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), CLAMP_QUANTILE)
            low_val = -hi_val
            up_weight = U.clamp(low_val, hi_val)
            down_weight = Vh.clamp(low_val, hi_val)

            if is_conv:
                up_weight = up_weight.unsqueeze(2).unsqueeze(3)
                down_weight = down_weight.unsqueeze(2).unsqueeze(3)

            merged_lora_sd[module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
            merged_lora_sd[module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
            merged_lora_sd[module_name + '.alpha'] = torch.tensor(new_rank)

    return merged_lora_sd
def merge_lora_models(models, ratios,sets):
    """Merge LoRAs of identical shape by adding their scaled weights.

    The alpha of the first model that contains a module becomes that module's
    base alpha; later models are rescaled by sqrt(alpha / base_alpha).

    Args:
        models: LoRA file paths.
        ratios: one list of per-block ratios per model.
        sets: option strings; "same to Strength" switches to square-root ratios.

    Returns:
        The merged state dict, including one ``.alpha`` entry per module.
    """
    base_alphas = {}  # module name -> alpha of the first model containing it
    base_dims = {}    # module name -> dim of the first model containing it
    merge_dtype = torch.float
    merged_sd = {}
    fugou = 1  # sign applied to lora_down; only updated by "same to Strength"

    for model_path, block_ratios in zip(models, ratios):
        print(f"merging {model_path}: {block_ratios}")
        lora_sd = load_state_dict(model_path, merge_dtype)

        # Pass 1: collect alpha and dim of every module in this model.
        alphas = {}
        dims = {}
        for key in lora_sd.keys():
            if 'alpha' in key:
                module_name = key[:key.rfind(".alpha")]
                value = float(lora_sd[key].detach().numpy())
                alphas[module_name] = value
                base_alphas.setdefault(module_name, value)
            elif "lora_down" in key:
                module_name = key[:key.rfind(".lora_down")]
                rank = lora_sd[key].size()[0]
                dims[module_name] = rank
                base_dims.setdefault(module_name, rank)

        # Modules without a stored alpha use their dim as alpha.
        for module_name in dims.keys():
            if module_name in alphas:
                continue
            value = dims[module_name]
            alphas[module_name] = value
            base_alphas.setdefault(module_name, value)

        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        # Pass 2: accumulate the scaled weights.
        print("merging...")
        for key in lora_sd.keys():
            if 'alpha' in key:
                continue

            module_name = key[:key.rfind(".lora_")]
            base_alpha = base_alphas[module_name]
            alpha = alphas[module_name]

            ratio = block_ratios[blockfromkey(key)]
            if "same to Strength" in sets:
                if ratio > 0:
                    ratio, fugou = ratio**0.5, 1
                else:
                    ratio, fugou = abs(ratio)**0.5, -1

            # The sign is carried by the down weight only.
            if "lora_down" in key:
                ratio = ratio * fugou

            scale = math.sqrt(alpha / base_alpha) * ratio

            if key not in merged_sd:
                merged_sd[key] = lora_sd[key] * scale
            else:
                assert merged_sd[key].size() == lora_sd[key].size(
                ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
                merged_sd[key] = merged_sd[key] + lora_sd[key] * scale

    # Store the base alpha of every module in the merged state dict.
    for module_name, alpha in base_alphas.items():
        merged_sd[module_name + ".alpha"] = torch.tensor(alpha)

    print("merged model")
    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    return merged_sd
def fullpathfromname(name):
    """Return the file path of the checkpoint best matching ``name``.

    An empty name (``""``, ``[]`` or ``None``) yields ``""``.
    """
    # The guard used to compare the builtin `hash` instead of `name`, so it never fired.
    if not name:
        return ""
    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
    return checkpoint_info.filename
def makeloraname(model_a,model_b):
    """Build the default LoRA name ``lora_<A>-<B>`` from two checkpoint names."""
    short_a = filenamecutter(model_a)
    short_b = filenamecutter(model_b)
    return "-".join(["lora_" + short_a, short_b])
def lycomerge(filename,ratios):
    """Apply per-block ratios to a single LyCORIS / LoCon file.

    Every non-alpha weight is multiplied by sqrt(|ratio|) of its block; for a
    negative ratio the weights whose key contains "down" are also negated.

    Args:
        filename: path of the file to load.
        ratios: 26 per-block ratios, or 17 LoRA-style ratios which are expanded
            to 26 by inserting 1 for the blocks a 17-entry preset does not cover.

    Returns:
        The rescaled state dict.
    """
    sd = load_state_dict(filename, torch.float)

    if len(ratios) == 17:
        r0 = 1
        ratios = [ratios[0]] + [r0] + ratios[1:3]+ [r0] + ratios[3:5]+[r0] + ratios[5:7]+[r0,r0,r0] + [ratios[7]] + [r0,r0,r0] + ratios[8:]

    print("LyCORIS: " , ratios)

    keys_failed_to_match = []

    # Iterate over a snapshot because the entries are replaced while looping.
    for lkey, weight in list(sd.items()):
        if 'alpha' in lkey:
            continue

        ratio = 1
        picked = False

        fullkey = convert_diffusers_name_to_compvis(lkey)
        key = fullkey.split(".", 1)[0]

        for i, block in enumerate(LYCOBLOCKS):
            if block in key:
                ratio = ratios[i]
                picked = True
        if not picked:
            keys_failed_to_match.append(key)

        scaled = weight * math.sqrt(abs(float(ratio)))
        if "down" in lkey and ratio < 0:
            # The flip used to target sd[key] (the module prefix), which is not a
            # state-dict entry; it has to hit the weight stored under lkey.
            scaled = scaled * -1
        sd[lkey] = scaled

    if len(keys_failed_to_match) > 0:
        print(keys_failed_to_match)

    return sd
# Name fragments identifying the 17 blocks addressed by LoRA block weights:
# text encoder, six input blocks, the middle block and nine output blocks.
LORABLOCKS = (
    ["encoder"]
    + [f"diffusion_model_input_blocks_{index}_" for index in (1, 2, 4, 5, 7, 8)]
    + ["diffusion_model_middle_block_"]
    + [f"diffusion_model_output_blocks_{index}_" for index in range(3, 12)]
)
# Name fragments identifying the 26 blocks addressed by LyCORIS block weights:
# text encoder, input blocks 0-11, the middle block and output blocks 0-11.
LYCOBLOCKS = (
    ["encoder"]
    + [f"diffusion_model_input_blocks_{index}_" for index in range(12)]
    + ["diffusion_model_middle_block_"]
    + [f"diffusion_model_output_blocks_{index}_" for index in range(12)]
)
class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter for one Linear or Conv2d module.

    Instead of replacing the wrapped module, ``apply_to`` swaps its ``forward``
    for this module's ``forward``, which adds the scaled LoRA output to the
    original result.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name

        wraps_conv = org_module.__class__.__name__ == "Conv2d"
        if wraps_conv:
            in_dim, out_dim = org_module.in_channels, org_module.out_channels
        else:
            in_dim, out_dim = org_module.in_features, org_module.out_features

        self.lora_dim = lora_dim

        if wraps_conv:
            # The down projection copies the wrapped conv's geometry; up is 1x1.
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, org_module.kernel_size, org_module.stride, org_module.padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()
        if alpha is None or alpha == 0:
            alpha = self.lora_dim
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # The up weight starts at zero, so a fresh module adds nothing to the output.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module
        self.region = None
        self.region_mask = None

    def apply_to(self):
        """Route the wrapped module's forward through this module, then drop the reference."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def merge_to(self, sd, dtype, device):
        """Add the LoRA weights found in ``sd`` to the wrapped module's weight."""
        up_weight = sd["lora_up.weight"].to(torch.float).to(device)
        down_weight = sd["lora_down.weight"].to(torch.float).to(device)

        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"].to(torch.float)

        if len(weight.size()) == 2:
            # linear
            delta = up_weight @ down_weight
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            delta = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            # conv2d with a larger kernel
            delta = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
        weight = weight + self.multiplier * delta * self.scale

        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    def set_region(self, region):
        """Set the region tensor and invalidate the cached mask."""
        self.region = region
        self.region_mask = None

    def forward(self, x):
        """Return the original output plus the (optionally region-masked) LoRA output."""
        if self.region is None or x.size()[1] % 77 == 0:
            # No region, or a sequence length that is a multiple of 77: the region
            # is dropped and the LoRA output is applied unmasked.
            self.region = None
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

        if self.region_mask is None:
            # Build the mask once, resized to this input's spatial size.
            if len(x.size()) == 4:
                h, w = x.size()[2:4]
            else:
                seq_len = x.size()[1]
                ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
                h = int(self.region.size()[0] / ratio + 0.5)
                w = seq_len // h

            mask = self.region.to(x.device)
            if mask.dtype == torch.bfloat16:
                mask = mask.to(torch.float)
            mask = mask.unsqueeze(0).unsqueeze(1)
            mask = torch.nn.functional.interpolate(mask, (h, w), mode="bilinear")
            mask = mask.to(x.dtype)

            if len(x.size()) == 3:
                mask = torch.reshape(mask, (1, x.size()[1], -1))

            self.region_mask = mask

        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    """Create a LoRANetwork for the given text encoder and U-Net.

    ``network_dim`` defaults to 4 when None. The optional ``conv_dim`` and
    ``conv_alpha`` keyword arguments configure the Conv2d LoRA modules:
    ``conv_dim`` is converted to int and ``conv_alpha`` to float (1.0 when
    omitted). ``conv_alpha`` is forwarded untouched when ``conv_dim`` is absent.
    ``vae`` is accepted but not used.
    """
    network_dim = 4 if network_dim is None else network_dim

    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        conv_alpha = 1.0 if conv_alpha is None else float(conv_alpha)

    return LoRANetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        conv_lora_dim=conv_dim,
        conv_alpha=conv_alpha,
    )
class LoRANetwork(torch.nn.Module): |
# Class names of U-Net modules whose Linear / Conv2d children receive LoRA modules.
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
# Additional U-Net class names, targeted when conv_lora_dim or modules_dim is given.
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
# Prefix of every U-Net LoRA module name (dots in the module path become underscores).
LORA_PREFIX_UNET = "lora_unet"
def __init__(
    self,
    text_encoder,
    unet,
    multiplier=1.0,
    lora_dim=4,
    alpha=1,
    conv_lora_dim=None,
    conv_alpha=None,
    modules_dim=None,
    modules_alpha=None,
) -> None:
    """Create LoRA modules for the text encoder and the U-Net.

    Args:
        text_encoder, unet: root modules to scan for target sub-modules.
        multiplier: strength handed to every LoRAModule.
        lora_dim, alpha: rank and alpha for Linear and 1x1 Conv2d modules.
        conv_lora_dim, conv_alpha: rank and alpha for the other Conv2d modules;
            they are only given a LoRA when conv_lora_dim is set.
        modules_dim, modules_alpha: per-module rank / alpha keyed by LoRA
            module name; when given, only the listed modules are created.

    Raises:
        AssertionError: if two modules end up with the same LoRA name.
    """
    super().__init__()
    self.multiplier = multiplier

    self.lora_dim = lora_dim
    self.alpha = alpha
    self.conv_lora_dim = conv_lora_dim
    self.conv_alpha = conv_alpha

    if modules_dim is not None:
        print(f"create LoRA network from weights")
    else:
        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")

    self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
    if self.apply_to_conv2d_3x3:
        if self.conv_alpha is None:
            self.conv_alpha = self.alpha
        print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

    def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
        # One LoRAModule per Linear / Conv2d found below a module whose class
        # name is listed in target_replace_modules.
        loras = []
        for name, module in root_module.named_modules():
            if module.__class__.__name__ not in target_replace_modules:
                continue
            for child_name, child_module in module.named_modules():
                is_linear = child_module.__class__.__name__ == "Linear"
                is_conv2d = child_module.__class__.__name__ == "Conv2d"
                is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
                if not (is_linear or is_conv2d):
                    continue

                lora_name = prefix + "." + name + "." + child_name
                lora_name = lora_name.replace(".", "_")

                if modules_dim is not None:
                    if lora_name not in modules_dim:
                        continue  # this module is not listed in the given weights
                    dim = modules_dim[lora_name]
                    alpha = modules_alpha[lora_name]
                elif is_linear or is_conv2d_1x1:
                    dim = self.lora_dim
                    alpha = self.alpha
                elif self.apply_to_conv2d_3x3:
                    dim = self.conv_lora_dim
                    alpha = self.conv_alpha
                else:
                    continue

                loras.append(LoRAModule(lora_name, child_module, self.multiplier, dim, alpha))
        return loras

    # NOTE(review): this call had lost its arguments, so it raised TypeError on
    # every construction. Prefix and targets are read from the class when it
    # defines them; the fallbacks follow the "lora_te_..." key pattern used in
    # this file and the upstream kohya-ss LoRA network -- confirm.
    te_prefix = getattr(LoRANetwork, "LORA_PREFIX_TEXT_ENCODER", "lora_te")
    te_targets = getattr(LoRANetwork, "TEXT_ENCODER_TARGET_REPLACE_MODULE", ["CLIPAttention", "CLIPMLP"])
    self.text_encoder_loras = create_modules(te_prefix, text_encoder, te_targets)
    print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

    # Copy first: `+=` on the class attribute itself would grow the shared
    # list on every construction.
    target_modules = list(LoRANetwork.UNET_TARGET_REPLACE_MODULE)
    if modules_dim is not None or self.conv_lora_dim is not None:
        target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

    self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
    print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

    self.weights_sd = None

    # LoRA names must be unique across both networks.
    names = set()
    for lora in self.text_encoder_loras + self.unet_loras:
        assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
        names.add(lora.lora_name)
def set_multiplier(self, multiplier): |
self.multiplier = multiplier |
for lora in self.text_encoder_loras + self.unet_loras: |
lora.multiplier = self.multiplier |
def load_weights(self, file): |
if os.path.splitext(file)[1] == ".safetensors": |
from safetensors.torch import load_file, safe_open |
self.weights_sd = load_file(file) |
else: |
self.weights_sd = torch.load(file, map_location="cpu") |
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): |
if self.weights_sd: |
weights_has_text_encoder = weights_has_unet = False |
for key in self.weights_sd.keys(): |
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): |
weights_has_text_encoder = True |
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): |
weights_has_unet = True |
if apply_text_encoder is None: |
apply_text_encoder = weights_has_text_encoder |
else: |
assert ( |
apply_text_encoder == weights_has_text_encoder |
), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" |
if apply_unet is None: |
apply_unet = weights_has_unet |
else: |
assert ( |
apply_unet == weights_has_unet |
), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" |
else: |
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" |
if apply_text_encoder: |
print("enable LoRA for text encoder") |
else: |
self.text_encoder_loras = [] |
if apply_unet: |
print("enable LoRA for U-Net") |
else: |
self.unet_loras = [] |
for lora in self.text_encoder_loras + self.unet_loras: |
lora.apply_to() |
self.add_module(lora.lora_name, lora) |
if self.weights_sd: |
info = self.load_state_dict(self.weights_sd, False) |
print(f"weights are loaded: {info}") |
def merge_to(self, text_encoder, unet, dtype, device): |
assert self.weights_sd is not None, "weights are not loaded" |
apply_text_encoder = apply_unet = False |
for key in self.weights_sd.keys(): |
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): |
apply_text_encoder = True |
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): |
apply_unet = True |
if apply_text_encoder: |
print("enable LoRA for text encoder") |
else: |
self.text_encoder_loras = [] |
if apply_unet: |
print("enable LoRA for U-Net") |
else: |
self.unet_loras = [] |
for lora in self.text_encoder_loras + self.unet_loras: |
sd_for_lora = {} |
for key in self.weights_sd.keys(): |
if key.startswith(lora.lora_name): |
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key] |
lora.merge_to(sd_for_lora, dtype, device) |
print(f"weights are merged") |
def enable_gradient_checkpointing(self): |
pass |
def prepare_optimizer_params(self, text_encoder_lr, unet_lr): |
def enumerate_params(loras): |
params = [] |
for lora in loras: |
params.extend(lora.parameters()) |
return params |
self.requires_grad_(True) |
all_params = [] |
if self.text_encoder_loras: |
param_data = {"params": enumerate_params(self.text_encoder_loras)} |
if text_encoder_lr is not None: |
param_data["lr"] = text_encoder_lr |
all_params.append(param_data) |
if self.unet_loras: |
param_data = {"params": enumerate_params(self.unet_loras)} |
if unet_lr is not None: |
param_data["lr"] = unet_lr |
all_params.append(param_data) |
return all_params |
def prepare_grad_etc(self, text_encoder, unet): |
self.requires_grad_(True) |
def on_epoch_start(self, text_encoder, unet): |
self.train() |
def get_trainable_params(self): |
return self.parameters() |
def save_weights(self, file, dtype, metadata): |
if metadata is not None and len(metadata) == 0: |
metadata = None |
state_dict = self.state_dict() |
if dtype is not None: |
for key in list(state_dict.keys()): |
v = state_dict[key] |
v = v.detach().clone().to("cpu").to(dtype) |
state_dict[key] = v |
if os.path.splitext(file)[1] == ".safetensors": |
from safetensors.torch import save_file |
if metadata is None: |
metadata = {} |
model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata) |
metadata["sshs_model_hash"] = model_hash |
metadata["sshs_legacy_hash"] = legacy_hash |
save_file(state_dict, file, metadata) |
else: |
torch.save(state_dict, file) |
@staticmethod |
def set_regions(networks, image): |
image = image.astype(np.float32) / 255.0 |
for i, network in enumerate(networks[:3]): |
region = image[:, :, i] |
if region.max() == 0: |
continue |
region = torch.tensor(region) |
network.set_region(region) |
def set_region(self, region): |
for lora in self.unet_loras: |
lora.set_region(region) |
from io import BytesIO |
import safetensors.torch |
import hashlib |
def precalculate_safetensors_hashes(tensors, metadata):
    """Compute the two model hashes used by sd-webui-additional-networks.

    The tensors are serialized in memory together with only those metadata
    entries whose key starts with ``"ss_"``; both hashes are computed from
    that serialized blob so the model need not be indexed again later.

    Returns:
        A ``(model_hash, legacy_hash)`` tuple.
    """
    hashed_metadata = {}
    for key, value in metadata.items():
        if key.startswith("ss_"):
            hashed_metadata[key] = value

    serialized = safetensors.torch.save(tensors, hashed_metadata)
    buffer = BytesIO(serialized)

    return addnet_hash_safetensors(buffer), addnet_hash_legacy(buffer)
def addnet_hash_safetensors(b):
    """Return the SHA-256 hex digest of a safetensors stream, excluding its header.

    This is the new model hash used by sd-webui-additional-networks for
    .safetensors files.  The first 8 bytes are read as a little-endian
    integer giving the header length; everything after those 8 bytes plus the
    header is hashed.

    Args:
        b: seekable binary stream positioned anywhere (it is rewound first).
    """
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()

    b.seek(0)
    header_length = int.from_bytes(b.read(8), "little")
    b.seek(header_length + 8)

    # Feed the remainder to the digest in 1 MiB pieces until the stream is exhausted.
    while True:
        piece = b.read(chunk_size)
        if piece == b"":
            break
        digest.update(piece)

    return digest.hexdigest()
def addnet_hash_legacy(b):
    """Return the old 8-character model hash used by sd-webui-additional-networks.

    The hash is the first 8 hex characters of the SHA-256 digest of up to
    0x10000 bytes read from offset 0x100000 of the stream.

    Args:
        b: seekable binary stream.
    """
    b.seek(0x100000)
    window = b.read(0x10000)
    full_digest = hashlib.sha256(window).hexdigest()
    return full_digest[:8]