|
import re |
|
from sklearn.linear_model import PassiveAggressiveClassifier |
|
import torch |
|
import math |
|
import os |
|
import gc |
|
import gradio as gr |
|
from torchmetrics import Precision |
|
import modules.shared as shared |
|
import gc |
|
from safetensors.torch import load_file, save_file |
|
from typing import List |
|
from tqdm import tqdm |
|
from modules import sd_models,scripts |
|
from scripts.mergers.model_util import load_models_from_stable_diffusion_checkpoint,filenamecutter,savemodel |
|
from modules.ui import create_refresh_button |
|
|
|
def on_ui_tabs():
    """Build the LoRA tab widgets and wire their event handlers.

    Creates the gradio components for merging LoRAs, baking LoRAs into a
    checkpoint and extracting a LoRA from two checkpoints, then connects the
    buttons to ``lmerge``, ``makelora`` and ``pluslora``.

    NOTE(review): the source this was recovered from had lost its indentation;
    which widgets belong to which ``gr.Row`` was reconstructed and should be
    confirmed against the rendered UI.
    """
    import lora
    sml_path_root = scripts.basedir()
    # Fallback block-weight presets, one "NAME:w0,...,w16" entry per line;
    # used only when the presets file below cannot be read.
    LWEIGHTSPRESETS="\
NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
IND:1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
INALL:1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
    sml_filepath = os.path.join(sml_path_root,"scripts", "lbwpresets.txt")
    sml_lbwpresets=""
    try:
        with open(sml_filepath,encoding="utf-8") as f:
            sml_lbwpresets = f.read()
    except OSError as e:
        # Presets file missing or unreadable: fall back to the built-in table.
        sml_lbwpresets=LWEIGHTSPRESETS

    with gr.Blocks(analytics_enabled=False) :
        # Status line that receives the string returned by every action below.
        sml_submit_result = gr.Textbox(label="Message")
        with gr.Row().style(equal_height=False):
            sml_cpmerge = gr.Button(elem_id="model_merger_merge", value="Merge to Checkpoint",variant='primary')
            sml_makelora = gr.Button(elem_id="model_merger_merge", value="Make LoRA (alpha * A - beta * B)",variant='primary')
            sml_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint A",interactive=True)
            create_refresh_button(sml_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
            sml_model_b = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint B",interactive=True)
            create_refresh_button(sml_model_b, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
        with gr.Row().style(equal_height=False):
            sml_merge = gr.Button(elem_id="model_merger_merge", value="Merge LoRAs",variant='primary')
            alpha = gr.Slider(label="alpha", minimum=-1.0, maximum=2, step=0.001, value=1)
            beta = gr.Slider(label="beta", minimum=-1.0, maximum=2, step=0.001, value=1)
        with gr.Row().style(equal_height=False):
            sml_settings = gr.CheckboxGroup(["same to Strength", "overwrite"], label="settings")
            precision = gr.Radio(label = "save precision",choices=["float","fp16","bf16"],value = "fp16",type="value")
        with gr.Row().style(equal_height=False):
            # Choices are "no", "auto" and the powers of two 4..1024.
            sml_dim = gr.Radio(label = "remake dimension",choices = ["no","auto",*[2**(x+2) for x in range(9)]],value = "no",type = "value")
        sml_filename = gr.Textbox(label="filename(option)",lines=1,visible =True,interactive = True)
        sml_loranames = gr.Textbox(label='LoRAname1:ratio1:Blocks1,LoRAname2:ratio2:Blocks2,...(":blocks" is option, not necessary)',lines=1,value="",visible =True)
        # Hidden until calculatedim() has produced a list of dimensions.
        sml_dims = gr.CheckboxGroup(label = "limit dimension",choices=[],value = [],type="value",interactive=True,visible = False)
        with gr.Row().style(equal_height=False):
            sml_calcdim = gr.Button(elem_id="calcloras", value="calculate dimension of LoRAs(It may take a few minutes if there are many LoRAs)",variant='primary')
            sml_update = gr.Button(elem_id="calcloras", value="update list",variant='primary')
        sml_loras = gr.CheckboxGroup(label = "Lora",choices=[x[0] for x in lora.available_loras.items()],type="value",interactive=True,visible = True)
        sml_loraratios = gr.TextArea(label="",value=sml_lbwpresets,visible =True,interactive = True)

        sml_merge.click(
            fn=lmerge,
            inputs=[sml_loranames,sml_loraratios,sml_settings,sml_filename,sml_dim,precision],
            outputs=[sml_submit_result]
        )

        sml_makelora.click(
            fn=makelora,
            inputs=[sml_model_a,sml_model_b,sml_dim,sml_filename,sml_settings,alpha,beta,precision],
            outputs=[sml_submit_result]
        )

        sml_cpmerge.click(
            fn=pluslora,
            inputs=[sml_loranames,sml_loraratios,sml_settings,sml_filename,sml_model_a,precision],
            outputs=[sml_submit_result]
        )
        # State shared by the closures below:
        #   llist: LoRA name -> detected dimension ("" until calculated)
        #   dlist: integer dimensions seen so far (kept sorted)
        #   dn:    non-integer dimension labels seen so far
        llist ={}
        dlist =[]
        dn = []

        def updateloras():
            """Rescan available LoRAs and refresh the checkbox choices as "name(dim)"."""
            lora.list_available_loras()
            for n in lora.available_loras.items():
                if n[0] not in llist:llist[n[0]] = ""
            return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()])

        sml_update.click(fn = updateloras,outputs = [sml_loras])

        def calculatedim():
            """Fill in the dimension of every LoRA not measured yet and update both lists."""
            print("listing dimensions...")
            for n in tqdm(lora.available_loras.items()):
                if n[0] in llist:
                    # Already measured on a previous run; skip the file load.
                    if llist[n[0]] !="": continue
                c_lora = lora.available_loras.get(n[0], None)
                d,t = dimgetter(c_lora.filename)
                if t == "LoCon" : d = f"{d}:{t}"
                if d not in dlist:
                    if type(d) == int :dlist.append(d)
                    elif d not in dn: dn.append(d)
                llist[n[0]] = d
            dlist.sort()
            return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()],value =[]),gr.update(visible =True,choices = [x for x in (dlist+dn)])

        sml_calcdim.click(
            fn=calculatedim,
            inputs=[],
            outputs=[sml_loras,sml_dims]
        )

        def dimselector(dims):
            """Restrict the LoRA checkbox choices to entries whose dimension is in ``dims``."""
            if dims ==[]:return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()])
            rl=[]
            for d in dims:
                for i in llist.items():
                    if d == i[1]:rl.append(f"{i[0]}({i[1]})")
            return gr.update(choices = [l for l in rl],value =[])

        def llister(names):
            """Turn the checked LoRAs into "name:1.0,name:1.0" text, dropping any "(dim)" suffix."""
            if names ==[] : return ""
            else:
                for i,n in enumerate(names):
                    if "(" in n:names[i] = n[:n.rfind("(")]
                return ":1.0,".join(names)+":1.0"
        sml_loras.change(fn=llister,inputs=[sml_loras],outputs=[sml_loranames])
        sml_dims.change(fn=dimselector,inputs=[sml_dims],outputs=[sml_loras])
|
|
|
def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,precision):
    """Extract a LoRA from ``alpha * model_a - beta * model_b`` and save it.

    Args:
        model_a, model_b: checkpoint names as selected in the UI dropdowns.
        dim: LoRA rank; any non-int value (e.g. "no"/"auto") falls back to 128.
        saveto: output file name; generated from the model names when empty.
        settings: collection of option strings; "overwrite" allows replacing
            an existing output file.
        alpha, beta: weights applied to model_a / model_b in the difference.
        precision: save precision name understood by ``str_to_dtype``.

    Returns:
        A status message string (success, or the reason nothing was saved).
    """
    print("make LoRA start")
    # An unset selection is not necessarily "" (pluslora, for instance, checks
    # for []), so treat every falsy value as "nothing selected".
    if not model_a or not model_b:
        return "ERROR: No model Selected"
    gc.collect()

    if not saveto:
        saveto = makeloraname(model_a, model_b)
    if ".safetensors" not in saveto:
        saveto += ".safetensors"
    saveto = os.path.join(shared.cmd_opts.lora_dir, saveto)

    # Exact type check on purpose: only a real int coming from the radio
    # button is accepted as a rank.
    dim = dim if type(dim) is int else 128
    if os.path.isfile(saveto) and "overwrite" not in settings:
        _err_msg = f"Output file ({saveto}) existed and was not saved"
        print(_err_msg)
        return _err_msg

    svd(fullpathfromname(model_a), fullpathfromname(model_b), False, dim, precision, saveto, alpha, beta)
    return f"saved to {saveto}"
|
|
|
def lmerge(loranames,loraratioss,settings,filename,dim,precision):
    """Merge one or more LoRAs into a single LoRA file.

    Args:
        loranames: "name:ratio[:preset]" entries separated by commas.
        loraratioss: preset table text, one "PRESET:w0,w1,..." per line; only
            lines with 17 or 26 comma-separated weights are used.
        settings: collection of option strings ("same to Strength",
            "overwrite").
        filename: output file name; derived from ``loranames`` when empty.
        dim: "no", "auto", or a target rank.
        precision: save precision name understood by ``str_to_dtype``.

    Returns:
        A status message string.
    """
    if not loranames or not loranames.strip():
        return "ERROR: No LoRA Selected"

    import lora

    entries = [entry.split(":") for entry in loranames.split(",")]

    # Rescan the LoRA directory only when a requested name is unknown.
    if any(lora.available_loras.get(entry[0], None) is None for entry in entries):
        lora.list_available_loras()

    # Preset name -> comma separated weight string. Names are stripped so the
    # lookup below matches regardless of surrounding whitespace.
    presets = {}
    for line in loraratioss.splitlines():
        if ":" not in line or line.count(",") not in (16, 25):
            continue
        parts = line.split(":")
        presets[parts[0].strip()] = parts[1]

    ln = []  # file names
    lr = []  # per-block ratio lists
    ld = []  # detected dimensions
    lt = []  # detected types
    dmax = 1

    for entry in entries:
        name = entry[0]
        try:
            strength = float(entry[1])
        except (IndexError, ValueError):
            return f"ERROR: invalid ratio for LoRA '{name}'"

        preset = presets.get(entry[2].strip()) if len(entry) == 3 else None
        if preset is not None:
            ratio = [float(r) * strength for r in preset.split(",")]
        else:
            ratio = [strength] * 17

        c_lora = lora.available_loras.get(name, None)
        if c_lora is None:
            return f"ERROR: LoRA not found: {name}"

        ln.append(c_lora.filename)
        lr.append(ratio)
        d, t = dimgetter(c_lora.filename)
        lt.append(t)
        ld.append(d)
        # dimgetter may return a string label ("LyCORIS", "unknown"); only
        # real integer ranks take part in the maximum.
        if isinstance(d, int) and d > dmax:
            dmax = d

    if not filename:
        filename = loranames.replace(",", "+").replace(":", "_")
    if ".safetensors" not in filename:
        filename += ".safetensors"
    filename = os.path.join(shared.cmd_opts.lora_dir, filename)

    # "auto" is a value of the dimension radio button, not of the settings.
    auto = dim == "auto"
    dim = int(dim) if dim not in (None, "no", "auto") else 0

    is_lyco = "LyCORIS" in ld or "LoCon" in lt
    if is_lyco and len(ld) != 1:
        return "multiple merge of LyCORIS is not supported"

    # Refuse before doing the expensive merge rather than after it.
    if os.path.isfile(filename) and "overwrite" not in settings:
        _err_msg = f"Output file ({filename}) existed and was not saved"
        print(_err_msg)
        return _err_msg

    if is_lyco:
        sd = lycomerge(ln[0], lr[0])
    elif dim > 0:
        print("change demension to ", dim)
        sd = merge_lora_models_dim(ln, lr, dim, settings)
    elif auto and ld.count(ld[0]) != len(ld):
        # Ranks differ between inputs: re-extract at the largest rank.
        print("change demension to ", dmax)
        sd = merge_lora_models_dim(ln, lr, dmax, settings)
    else:
        sd = merge_lora_models(ln, lr, settings)

    save_to_file(filename, sd, sd, str_to_dtype(precision))
    return "saved : " + filename
|
|
|
def pluslora(lnames,loraratios,settings,output,model,precision):
    """Bake one or more LoRAs into a checkpoint and save the result.

    Args:
        lnames: "name:ratio[:preset]" entries separated by commas.
        loraratios: preset table text, one "PRESET:w0,w1,..." per line; only
            lines with 17 or 26 comma-separated weights are used.
        settings: collection of option strings, forwarded to ``savemodel``
            together with ``precision``.
        output: output name forwarded to ``savemodel``.
        model: checkpoint name as selected in the UI dropdown.
        precision: save precision name appended to the settings.

    Returns:
        The message returned by ``savemodel``, or an error message string.
    """
    if not model:
        return "ERROR: No model Selected"
    if not lnames or not lnames.strip():
        return "ERROR: No LoRA Selected"

    print("plus LoRA start")
    import lora

    entries = [entry.split(":") for entry in lnames.split(",")]

    # Preset name -> comma separated weight string (names stripped).
    presets = {}
    for line in loraratios.splitlines():
        if ":" not in line or line.count(",") not in (16, 25):
            continue
        parts = line.split(":")
        presets[parts[0].strip()] = parts[1]

    names = []
    filenames = []
    lweis = []

    for entry in entries:
        name = entry[0]
        try:
            strength = float(entry[1])
        except (IndexError, ValueError):
            return f"ERROR: invalid ratio for LoRA '{name}'"

        # Look the preset up with the same stripped key used for the test.
        preset = presets.get(entry[2].strip()) if len(entry) == 3 else None
        if preset is not None:
            ratio = [float(r) * strength for r in preset.split(",")]
        else:
            ratio = [strength] * 17

        c_lora = lora.available_loras.get(name, None)
        if c_lora is None:
            return f"ERROR: LoRA not found: {name}"
        _, t = dimgetter(c_lora.filename)
        if "LyCORIS" in t:
            return "LyCORIS merge is not supported"

        names.append(name)
        filenames.append(c_lora.filename)
        lweis.append(ratio)

    modeln = filenamecutter(model, True)
    dname = modeln
    for n in names:
        dname = dname + "+" + n

    checkpoint_info = sd_models.get_closet_checkpoint_match(model)
    print(f"Loading {model}")
    theta_0 = sd_models.read_state_dict(checkpoint_info.filename, "cpu")

    # Map "<module path with dots replaced by underscores>" (the part after
    # the first "model_") to the real checkpoint key.
    keychanger = {}
    for key in theta_0.keys():
        if "model" in key:
            skey = key.replace(".", "_").replace("_weight", "")
            _, found, tail = skey.partition("model_")
            if found:
                keychanger[tail] = key

    for name, filename, lwei in zip(names, filenames, lweis):
        print(f"loading: {name}")
        lora_sd = load_state_dict(filename, torch.float)

        print(f"merging...", lwei)
        for key in lora_sd.keys():
            # Each module is handled once, starting from its lora_down weight.
            if "lora_down" not in key:
                continue

            fullkey = convert_diffusers_name_to_compvis(key)

            # Last matching block decides the ratio; unmatched keys use 1.
            ratio = 1
            for i, block in enumerate(LORABLOCKS):
                if block in fullkey:
                    ratio = lwei[i]

            msd_key = fullkey.split(".", 1)[0]
            up_key = key.replace("lora_down", "lora_up")
            alpha_key = key[:key.index("lora_down")] + 'alpha'

            down_weight = lora_sd[key].to(device="cpu")
            up_weight = lora_sd[up_key].to(device="cpu")

            dim = down_weight.size()[0]
            alpha = lora_sd.get(alpha_key, dim)
            scale = alpha / dim

            target_key = keychanger[msd_key]
            weight = theta_0[target_key].to(device="cpu")

            if len(down_weight.size()) != 4:
                # Linear layer.
                weight = weight + ratio * (up_weight @ down_weight) * scale
            else:
                # Conv layer stored as (out, in, 1, 1): multiply as matrices.
                weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                                           ).unsqueeze(2).unsqueeze(3) * scale
            theta_0[target_key] = torch.nn.Parameter(weight)

    # Build a new list instead of appending to the caller's settings object.
    save_settings = [*(settings or []), precision]
    result = savemodel(theta_0, dname, output, save_settings, model)
    del theta_0
    gc.collect()
    return result
|
|
|
def save_to_file(file_name, model, state_dict, dtype):
    """Cast plain tensors in ``state_dict`` to ``dtype`` (in place) and write ``model``.

    ``model`` is the object that gets written; the cast is applied to
    ``state_dict`` (callers in this file pass the same dict for both). A
    ``.safetensors`` extension selects ``save_file``; anything else goes
    through ``torch.save``. ``dtype=None`` skips the cast.
    """
    if dtype is not None:
        for name in tuple(state_dict):
            entry = state_dict[name]
            # Exact type match: subclasses of Tensor are left untouched.
            if type(entry) == torch.Tensor:
                state_dict[name] = entry.to(dtype)

    _, extension = os.path.splitext(file_name)
    if extension != '.safetensors':
        torch.save(model, file_name)
        return
    save_file(model, file_name)
|
|
|
# Matches a run of digits; used to decide which regex groups become ints.
re_digits = re.compile(r"\d+")

# Attention modules of the U-Net (down / mid / up).
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")

# Resnet modules of the U-Net (down / mid / up).
re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")

# Down/up-sampling convolutions.
re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")

# Text encoder layers.
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")

# Resnet sub-module prefix in the LoRA key -> replacement in the converted key.
_RESNET_LAYER_NAMES = (
    ("conv1", "in_layers_2"),
    ("conv2", "out_layers_3"),
    ("time_emb_proj", "emb_layers_1"),
    ("conv_shortcut", "skip_connection"),
)


def _convert_resnet_suffix(block, suffix):
    """Return ``block`` plus the renamed resnet sub-module, or None if unknown."""
    suffix = str(suffix)
    for source_name, target_name in _RESNET_LAYER_NAMES:
        if suffix.startswith(source_name):
            return f"{block}{target_name}{suffix[len(source_name):]}"
    return None


def convert_diffusers_name_to_compvis(key):
    """Translate a ``lora_unet_*`` / ``lora_te_*`` key into the checkpoint-style name.

    Keys matching none of the known patterns (or a resnet pattern with an
    unknown sub-module) are returned unchanged.
    """
    def match(match_list, regex):
        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        # Only groups that consist solely of digits are block indices; a
        # prefix test would try int() on suffixes such as "1x" and raise.
        match_list.extend([int(x) if re_digits.fullmatch(x) else x for x in r.groups()])
        return True

    m = []

    if match(m, re_unet_down_blocks):
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_mid_blocks):
        return f"diffusion_model_middle_block_1_{m[1]}"

    if match(m, re_unet_up_blocks):
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_down_blocks_res):
        converted = _convert_resnet_suffix(f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_", m[2])
        if converted is not None:
            return converted

    if match(m, re_unet_mid_blocks_res):
        converted = _convert_resnet_suffix(f"diffusion_model_middle_block_{m[0]*2}_", m[1])
        if converted is not None:
            return converted

    if match(m, re_unet_up_blocks_res):
        converted = _convert_resnet_suffix(f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_", m[2])
        if converted is not None:
            return converted

    if match(m, re_unet_downsample):
        return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"

    if match(m, re_unet_upsample):
        # The conv sits at sub-index 1 for up block 0 and at 2 otherwise.
        return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"

    if match(m, re_text_block):
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key
|
|
|
# Quantile of the combined U / Vh values used as the symmetric clamp bound
# after the SVD in svd() and merge_lora_models_dim().
CLAMP_QUANTILE = 0.99
# Largest absolute text-encoder weight difference that svd() still treats as
# "text encoders are the same".
MIN_DIFF = 1e-6
|
|
|
def str_to_dtype(p):
    """Map a precision name ('float', 'fp16', 'bf16') to a torch dtype, else None."""
    known = (
        ('float', torch.float),
        ('fp16', torch.float16),
        ('bf16', torch.bfloat16),
    )
    for label, dtype in known:
        if p == label:
            return dtype
    return None
|
|
|
def svd(model_a,model_b,v2,dim,save_precision,save_to,alpha,beta):
    """Extract a LoRA approximating ``alpha * model_a - beta * model_b`` and save it.

    Args:
        model_a: checkpoint path of the "tuned" model (suffix ``_t`` below).
        model_b: checkpoint path of the "original" model (suffix ``_o`` below).
        v2: forwarded to ``load_models_from_stable_diffusion_checkpoint``.
        dim: rank kept from the SVD; also used as the network dim and alpha.
        save_precision: precision name understood by ``str_to_dtype``.
        save_to: output path; its directory is created when missing.
        alpha, beta: weights applied to the model_a / model_b module weights.

    Returns:
        ``save_to``.
    """
    save_dtype = str_to_dtype(save_precision)

    if model_a == model_b:
        # Same checkpoint on both sides: load once and reuse the objects.
        text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)
        text_encoder_o, _, unet_o = text_encoder_t, _, unet_t
    else:
        print(f"loading SD model : {model_b}")
        text_encoder_o, _, unet_o = load_models_from_stable_diffusion_checkpoint(v2, model_b)

        print(f"loading SD model : {model_a}")
        text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)

    # Build one LoRA network per model; they are used to pair up the target
    # modules of both models by position.
    lora_network_o = create_network(1.0, dim, dim, None, text_encoder_o, unet_o)
    lora_network_t = create_network(1.0, dim, dim, None, text_encoder_t, unet_t)
    assert len(lora_network_o.text_encoder_loras) == len(
        lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "

    # lora_name -> weighted difference of the paired module weights.
    diffs = {}
    text_encoder_different = False
    for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
        lora_name = lora_o.lora_name
        module_o = lora_o.org_module
        module_t = lora_t.org_module
        diff = alpha*module_t.weight - beta*module_o.weight

        # Any element above MIN_DIFF marks the text encoders as different.
        if torch.max(torch.abs(diff)) > MIN_DIFF:
            text_encoder_different = True

        diff = diff.float()
        diffs[lora_name] = diff

    if not text_encoder_different:
        # Drop the text-encoder modules and their collected differences.
        print("Text encoder is same. Extract U-Net only.")
        lora_network_o.text_encoder_loras = []
        diffs = {}

    for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
        lora_name = lora_o.lora_name
        module_o = lora_o.org_module
        module_t = lora_t.org_module
        diff = alpha*module_t.weight - beta*module_o.weight
        diff = diff.float()

        diffs[lora_name] = diff

    print("calculating by svd")
    rank = dim
    # lora_name -> (up matrix, down matrix)
    lora_weights = {}
    with torch.no_grad():
        for lora_name, mat in tqdm(list(diffs.items())):
            conv2d = (len(mat.size()) == 4)
            if conv2d:
                # NOTE(review): presumably these are 1x1 conv weights so that
                # squeeze() yields a 2-D matrix — confirm.
                mat = mat.squeeze()

            U, S, Vh = torch.linalg.svd(mat)

            # Keep the first `rank` components and fold S into U.
            U = U[:, :rank]
            S = S[:rank]
            U = U @ torch.diag(S)

            Vh = Vh[:rank, :]

            # Clamp both factors to +/- the CLAMP_QUANTILE quantile of their
            # combined values.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
            low_val = -hi_val

            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            lora_weights[lora_name] = (U, Vh)

    lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True)
    lora_sd = lora_network_o.state_dict()
    print(f"LoRA has {len(lora_sd)} weights.")

    # Overwrite the network's freshly initialised weights with the factors.
    for key in list(lora_sd.keys()):
        if "alpha" in key:
            continue

        lora_name = key.split('.')[0]
        # Index 0 is the up factor (U), index 1 the down factor (Vh).
        i = 0 if "lora_up" in key else 1

        weights = lora_weights[lora_name][i]

        if len(lora_sd[key].size()) == 4:
            weights = weights.unsqueeze(2).unsqueeze(3)

        assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
        lora_sd[key] = weights

    info = lora_network_o.load_state_dict(lora_sd)
    print(f"Loading extracted LoRA weights: {info}")

    dir_name = os.path.dirname(save_to)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)

    # Both dim and alpha are recorded as `dim`.
    metadata = {"ss_network_dim": str(dim), "ss_network_alpha": str(dim)}

    lora_network_o.save_weights(save_to, save_dtype, metadata)
    print(f"LoRA weights are saved to: {save_to}")
    return save_to
|
|
|
def load_state_dict(file_name, dtype):
    """Load a state dict from ``file_name`` and cast its plain tensors to ``dtype``.

    A ``.safetensors`` extension is read with ``load_file``; every other file
    goes through ``torch.load`` mapped to the CPU.
    """
    _, extension = os.path.splitext(file_name)
    if extension == '.safetensors':
        sd = load_file(file_name)
    else:
        # NOTE(review): torch.load may unpickle arbitrary objects; only use
        # it on files from a trusted source.
        sd = torch.load(file_name, map_location='cpu')

    for name in list(sd.keys()):
        entry = sd[name]
        # Exact type match: non-tensor entries and Tensor subclasses are kept.
        if type(entry) == torch.Tensor:
            sd[name] = entry.to(dtype)
    return sd
|
|
|
def dimgetter(filename):
    """Return ``(dim, kind)`` for the LoRA file at ``filename``.

    ``dim`` is the first dimension of the first 2-D ``lora_down`` tensor found;
    ``kind`` is "LoCon", "LyCORIS" or "LoRA". Any key containing "hada_" turns
    both values into "LyCORIS". When no dimension is found the result is
    ``("unknown", "unknown")``.
    """
    lora_sd = load_state_dict(filename, torch.float)

    locon_marker = "lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight"
    kind = "LoCon" if locon_marker in lora_sd.keys() else None

    alpha = None
    dim = None
    for key, value in lora_sd.items():
        if alpha is None and 'alpha' in key:
            alpha = value
        if dim is None and 'lora_down' in key and len(value.size()) == 2:
            dim = value.size()[0]
        if "hada_" in key:
            dim, kind = "LyCORIS", "LyCORIS"
        # Stop scanning once both an alpha and a dimension have been seen.
        if not (alpha is None or dim is None):
            break

    if kind is None:
        kind = "LoRA"
    if not dim:
        return "unknown", "unknown"
    return dim, kind
|
|
|
def blockfromkey(key):
    """Return the index of the first LORABLOCKS entry found in the converted key (0 if none)."""
    converted = convert_diffusers_name_to_compvis(key)
    matches = (index for index, marker in enumerate(LORABLOCKS) if marker in converted)
    return next(matches, 0)
|
|
|
def merge_lora_models_dim(models, ratios, new_rank,sets):
    """Merge LoRAs as full weight deltas and re-extract them at rank ``new_rank``.

    Args:
        models: LoRA file names.
        ratios: one per-block ratio list per model (indexed by ``blockfromkey``).
        new_rank: number of SVD components kept; also written as the alpha.
        sets: collection of option strings; "same to Strength" switches the
            ratio to sign * sqrt(|ratio|).

    Returns:
        State dict with ``.lora_up.weight``, ``.lora_down.weight`` and
        ``.alpha`` entries for every merged module.
    """
    summed = {}  # module name -> accumulated full-size delta
    fugou = 1    # sign carried next to the (possibly square-rooted) ratio
    for model, block_ratios in zip(models, ratios):
        merge_dtype = torch.float
        lora_sd = load_state_dict(model, merge_dtype)

        print(f"merging {model}: {block_ratios}")
        for key in tqdm(list(lora_sd.keys())):
            # Every module is processed once, via its lora_down entry.
            if 'lora_down' not in key:
                continue
            module_name = key[:key.rfind(".lora_down")]

            down_weight = lora_sd[key]
            network_dim = down_weight.size()[0]

            up_weight = lora_sd[module_name + '.lora_up.weight']
            alpha = lora_sd.get(module_name + '.alpha', network_dim)

            in_dim = down_weight.size()[1]
            out_dim = up_weight.size()[0]
            conv2d = len(down_weight.size()) == 4

            if module_name in summed:
                weight = summed[module_name]
            else:
                shape = (out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim)
                weight = torch.zeros(shape, dtype=merge_dtype)

            ratio = block_ratios[blockfromkey(key)]
            if "same to Strength" in sets:
                if ratio > 0:
                    ratio, fugou = ratio**0.5, 1
                else:
                    ratio, fugou = abs(ratio)**0.5, -1

            scale = (alpha / network_dim)
            if conv2d:
                weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                                           ).unsqueeze(2).unsqueeze(3) * scale * fugou
            else:
                weight = weight + ratio * (up_weight @ down_weight) * scale * fugou

            summed[module_name] = weight

    print("extract new lora...")
    merged_lora_sd = {}
    with torch.no_grad():
        for module_name, mat in tqdm(list(summed.items())):
            conv2d = (len(mat.size()) == 4)
            if conv2d:
                mat = mat.squeeze()

            U, S, Vh = torch.linalg.svd(mat)

            # Keep the leading components and fold the singular values into U.
            U = U[:, :new_rank]
            S = S[:new_rank]
            U = U @ torch.diag(S)
            Vh = Vh[:new_rank, :]

            # Symmetric clamp at the CLAMP_QUANTILE quantile of both factors.
            hi_val = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), CLAMP_QUANTILE)
            low_val = -hi_val
            up_weight = U.clamp(low_val, hi_val)
            down_weight = Vh.clamp(low_val, hi_val)

            if conv2d:
                up_weight = up_weight.unsqueeze(2).unsqueeze(3)
                down_weight = down_weight.unsqueeze(2).unsqueeze(3)

            merged_lora_sd[module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
            merged_lora_sd[module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
            merged_lora_sd[module_name + '.alpha'] = torch.tensor(new_rank)

    return merged_lora_sd
|
|
|
def merge_lora_models(models, ratios,sets):
    """Merge LoRAs of identical shape by summing their scaled up/down weights.

    Args:
        models: LoRA file names.
        ratios: one per-block ratio list per model (indexed by ``blockfromkey``).
        sets: collection of option strings; "same to Strength" switches the
            ratio to sqrt(|ratio|) with the sign applied to lora_down only.

    Returns:
        Merged state dict; each module's ``.alpha`` is the alpha of the first
        model that contained that module.
    """
    base_alphas = {}  # alpha of the first model seen, per module
    base_dims = {}    # dim of the first model seen, per module
    merge_dtype = torch.float
    merged_sd = {}
    fugou = 1
    for model, block_ratios in zip(models, ratios):
        print(f"merging {model}: {block_ratios}")
        lora_sd = load_state_dict(model, merge_dtype)

        # Collect this model's alpha and dim per module.
        alphas = {}
        dims = {}
        for key in lora_sd.keys():
            if 'alpha' in key:
                module_name = key[:key.rfind(".alpha")]
                alpha = float(lora_sd[key].detach().numpy())
                alphas[module_name] = alpha
                base_alphas.setdefault(module_name, alpha)
            elif "lora_down" in key:
                module_name = key[:key.rfind(".lora_down")]
                dim = lora_sd[key].size()[0]
                dims[module_name] = dim
                base_dims.setdefault(module_name, dim)

        # Modules without a stored alpha use their dim as alpha.
        for module_name, dim in dims.items():
            if module_name in alphas:
                continue
            alphas[module_name] = dim
            base_alphas.setdefault(module_name, dim)

        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        print("merging...")
        for key in lora_sd.keys():
            if 'alpha' in key:
                continue
            module_name = key[:key.rfind(".lora_")]

            base_alpha = base_alphas[module_name]
            alpha = alphas[module_name]

            ratio = block_ratios[blockfromkey(key)]
            if "same to Strength" in sets:
                if ratio > 0:
                    ratio, fugou = ratio**0.5, 1
                else:
                    ratio, fugou = abs(ratio)**0.5, -1

            # The sign is applied on the down weight only.
            if "lora_down" in key:
                ratio = ratio * fugou

            scale = math.sqrt(alpha / base_alpha) * ratio

            if key not in merged_sd:
                merged_sd[key] = lora_sd[key] * scale
                continue

            assert merged_sd[key].size() == lora_sd[key].size(
            ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
            merged_sd[key] = merged_sd[key] + lora_sd[key] * scale

    for module_name, alpha in base_alphas.items():
        merged_sd[module_name + ".alpha"] = torch.tensor(alpha)

    print("merged model")
    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    return merged_sd
|
|
|
def fullpathfromname(name):
    """Return the checkpoint file path for ``name``, or "" when no name is given.

    The lookup goes through ``sd_models.get_closet_checkpoint_match``.
    """
    # The guard has to test the argument itself; comparing the builtin
    # ``hash`` against "" or [] can never be true.
    if not name:
        return ""
    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
    return checkpoint_info.filename
|
|
|
def makeloraname(model_a,model_b):
    """Build the default output name "lora_<a>-<b>" from two checkpoint names."""
    short_names = (filenamecutter(model_a), filenamecutter(model_b))
    return "lora_" + "-".join(short_names)
|
|
|
def lycomerge(filename,ratios):
    """Scale the weights of a single LyCORIS/LoCon file by per-block ratios.

    Args:
        filename: file to load.
        ratios: 17 or 26 per-block ratios. A 17-entry list is expanded to the
            26-entry LYCOBLOCKS layout, filling the extra blocks with 1.

    Returns:
        The loaded state dict with every non-alpha tensor multiplied by
        sqrt(|ratio|) and "down" tensors negated for negative ratios.
    """
    sd = load_state_dict(filename, torch.float)

    if len(ratios) == 17:
        r0 = 1
        ratios = [ratios[0]] + [r0] + ratios[1:3]+ [r0] + ratios[3:5]+[r0] + ratios[5:7]+[r0,r0,r0] + [ratios[7]] + [r0,r0,r0] + ratios[8:]

    print("LyCORIS: " , ratios)

    keys_failed_to_match = []

    # Only existing keys are reassigned, so the dict size never changes
    # while it is being iterated.
    for lkey, weight in sd.items():
        if 'alpha' in lkey:
            continue

        fullkey = convert_diffusers_name_to_compvis(lkey)
        key = fullkey.split(".", 1)[0]

        # Last matching block decides the ratio; unmatched keys keep 1.
        ratio = 1
        picked = False
        for i, block in enumerate(LYCOBLOCKS):
            if block in key:
                ratio = ratios[i]
                picked = True
        if not picked:
            keys_failed_to_match.append(key)

        scaled = weight * math.sqrt(abs(float(ratio)))

        # Flip the sign on the entry being processed. Indexing the dict with
        # the converted module name (which is not a key of sd) raised
        # KeyError for every negative ratio.
        if "down" in lkey and ratio < 0:
            scaled = scaled * -1

        sd[lkey] = scaled

    if len(keys_failed_to_match) > 0:
        print(keys_failed_to_match)

    return sd
|
|
|
# Substrings identifying the 17 weight blocks of a standard LoRA, in ratio
# order: text encoder, six input blocks, the middle block, nine output blocks.
LORABLOCKS = (
    ["encoder"]
    + [f"diffusion_model_input_blocks_{index}_" for index in (1, 2, 4, 5, 7, 8)]
    + ["diffusion_model_middle_block_"]
    + [f"diffusion_model_output_blocks_{index}_" for index in range(3, 12)]
)
|
|
|
# Substrings identifying the 26 weight blocks used for LyCORIS files, in
# ratio order: text encoder, input blocks 0-11, the middle block, output
# blocks 0-11.
LYCOBLOCKS = (
    ["encoder"]
    + [f"diffusion_model_input_blocks_{index}_" for index in range(12)]
    + ["diffusion_model_middle_block_"]
    + [f"diffusion_model_output_blocks_{index}_" for index in range(12)]
)
|
|
|
class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter that hooks the forward method of an existing Linear or
    Conv2d module instead of replacing the module itself.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
        """A missing (None) or zero alpha falls back to the rank, i.e. scale 1."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        # lora_down is created before lora_up in both branches.
        if org_module.__class__.__name__ == "Conv2d":
            in_dim, out_dim = org_module.in_channels, org_module.out_channels
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, org_module.kernel_size, org_module.stride, org_module.padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            in_dim, out_dim = org_module.in_features, org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()
        if alpha is None or alpha == 0:
            alpha = self.lora_dim
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # The up projection starts at zero, so the adapter adds nothing
        # until weights are loaded.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module
        self.region = None
        self.region_mask = None

    def _lora_delta(self, x):
        """Adapter output for ``x``, scaled by multiplier and scale."""
        return self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

    def apply_to(self):
        """Redirect the original module's forward to this adapter, then drop the reference."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def merge_to(self, sd, dtype, device):
        """Add the LoRA weights in ``sd`` to the original module's weight and store it as ``dtype``."""
        up_weight = sd["lora_up.weight"].to(torch.float).to(device)
        down_weight = sd["lora_down.weight"].to(torch.float).to(device)

        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"].to(torch.float)

        if len(weight.size()) == 2:
            # Linear layer.
            delta = up_weight @ down_weight
        elif down_weight.size()[2:4] == (1, 1):
            # 1x1 convolution: multiply as plain matrices.
            delta = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            # Larger kernels: compose the two convolutions.
            delta = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)

        if len(weight.size()) == 2 or down_weight.size()[2:4] == (1, 1):
            weight = weight + self.multiplier * delta * self.scale
        else:
            weight = weight + self.multiplier * delta * self.scale

        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    def set_region(self, region):
        """Set the region tensor and invalidate the cached mask."""
        self.region = region
        self.region_mask = None

    def forward(self, x):
        """Original forward plus the adapter output, optionally masked by the region."""
        if self.region is None:
            return self.org_forward(x) + self._lora_delta(x)

        # A second dimension that is a multiple of 77 disables the region.
        if x.size()[1] % 77 == 0:
            self.region = None
            return self.org_forward(x) + self._lora_delta(x)

        if self.region_mask is None:
            if len(x.size()) == 4:
                h, w = x.size()[2:4]
            else:
                # Derive a 2-D size from the sequence length.
                seq_len = x.size()[1]
                ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
                h = int(self.region.size()[0] / ratio + 0.5)
                w = seq_len // h

            mask = self.region.to(x.device)
            if mask.dtype == torch.bfloat16:
                mask = mask.to(torch.float)
            mask = mask.unsqueeze(0).unsqueeze(1)

            mask = torch.nn.functional.interpolate(mask, (h, w), mode="bilinear")
            mask = mask.to(x.dtype)

            if len(x.size()) == 3:
                mask = torch.reshape(mask, (1, x.size()[1], -1))

            self.region_mask = mask

        return self.org_forward(x) + self._lora_delta(x) * self.region_mask
|
|
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    """Build a ``LoRANetwork`` for the given text encoder and U-Net.

    Args:
        multiplier: Passed to ``LoRANetwork`` as ``multiplier``.
        network_dim: LoRA dim (rank); ``None`` falls back to 4.
        network_alpha: Passed to ``LoRANetwork`` as ``alpha``.
        vae: Accepted but not used by this function.
        text_encoder: Text encoder module to attach LoRA modules to.
        unet: U-Net module to attach LoRA modules to.
        **kwargs: Only ``conv_dim`` and ``conv_alpha`` are read; any other
            keyword (e.g. per-block dims/alphas) is ignored.

    Returns:
        The constructed ``LoRANetwork``.
    """
    if network_dim is None:
        network_dim = 4

    conv_dim = kwargs.get("conv_dim")
    conv_alpha = kwargs.get("conv_alpha")
    if conv_dim is not None:
        # Values are converted with int()/float(); they presumably arrive as
        # strings from the caller's arguments - confirm.
        conv_dim = int(conv_dim)
        conv_alpha = 1.0 if conv_alpha is None else float(conv_alpha)

    return LoRANetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        conv_lora_dim=conv_dim,
        conv_alpha=conv_alpha,
    )
|
|
|
|
|
|
|
class LoRANetwork(torch.nn.Module):
    """Set of LoRA modules created for a text encoder and a U-Net.

    Modules are discovered by class name: every ``Linear``/``Conv2d`` found
    under a module whose class name is listed in the target lists below gets a
    ``LoRAModule``. The resulting modules are kept in ``text_encoder_loras`` and
    ``unet_loras``.
    """

    # Class names (compared with ``module.__class__.__name__``) whose
    # Linear / Conv2d descendants receive LoRA modules.
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    # Extra U-Net targets, used when conv LoRA is enabled or when the network
    # is built from per-module dims.
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    # Prefixes of the generated LoRA module names / state-dict keys.
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder,
        unet,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        conv_lora_dim=None,
        conv_alpha=None,
        modules_dim=None,
        modules_alpha=None,
    ) -> None:
        """Create the LoRA modules for ``text_encoder`` and ``unet``.

        Args:
            text_encoder: Module searched for text-encoder targets.
            unet: Module searched for U-Net targets.
            multiplier: Multiplier handed to every created ``LoRAModule``.
            lora_dim: Dim (rank) for Linear and 1x1 Conv2d layers.
            alpha: Alpha for Linear and 1x1 Conv2d layers.
            conv_lora_dim: Dim for the remaining Conv2d layers; ``None``
                disables LoRA for them.
            conv_alpha: Alpha for those Conv2d layers; defaults to ``alpha``
                when conv LoRA is enabled and this is ``None``.
            modules_dim: Optional mapping ``lora_name -> dim``. When given,
                only the listed modules are created and ``lora_dim`` /
                ``conv_lora_dim`` are not consulted.
            modules_alpha: Mapping ``lora_name -> alpha``; read whenever
                ``modules_dim`` is given.

        Raises:
            AssertionError: If two created modules share the same LoRA name.
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha

        if modules_dim is not None:
            print("create LoRA network from weights")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")

        self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
        if self.apply_to_conv2d_3x3:
            if self.conv_alpha is None:
                self.conv_alpha = self.alpha
            print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List["LoRAModule"]:
            """Create one LoRAModule per eligible Linear/Conv2d under the target modules."""
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue

                for child_name, child_module in module.named_modules():
                    child_class = child_module.__class__.__name__
                    is_linear = child_class == "Linear"
                    is_conv2d = child_class == "Conv2d"
                    if not (is_linear or is_conv2d):
                        continue
                    is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                    # Dotted module path, flattened with underscores.
                    lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")

                    if modules_dim is not None:
                        # Built from weights: only modules present there.
                        if lora_name not in modules_dim:
                            continue
                        dim = modules_dim[lora_name]
                        module_alpha = modules_alpha[lora_name]
                    elif is_linear or is_conv2d_1x1:
                        dim, module_alpha = self.lora_dim, self.alpha
                    elif self.apply_to_conv2d_3x3:
                        dim, module_alpha = self.conv_lora_dim, self.conv_alpha
                    else:
                        continue

                    loras.append(LoRAModule(lora_name, child_module, self.multiplier, dim, module_alpha))
            return loras

        self.text_encoder_loras = create_modules(
            LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
        )
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # Work on a copy: extending the class-level list in place would leak the
        # conv targets into every network created afterwards.
        target_modules = list(LoRANetwork.UNET_TARGET_REPLACE_MODULE)
        if modules_dim is not None or self.conv_lora_dim is not None:
            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # Each LoRA name must be unique; it is later used as a submodule name.
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def _weight_targets(self):
        """Return ``(has_text_encoder, has_unet)`` from the key prefixes in ``self.weights_sd``."""
        has_text_encoder = has_unet = False
        for key in self.weights_sd:
            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                has_text_encoder = True
            elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
                has_unet = True
        return has_text_encoder, has_unet

    def set_multiplier(self, multiplier):
        """Set the multiplier on the network and on every LoRA module it holds."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Read a LoRA state dict from ``file`` into ``self.weights_sd``.

        Files ending in ``.safetensors`` are read with safetensors; anything
        else goes through ``torch.load`` mapped to CPU.
        """
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            self.weights_sd = load_file(file)
        else:
            # NOTE: torch.load unpickles the file; only use it on trusted files.
            self.weights_sd = torch.load(file, map_location="cpu")

    def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
        """Activate the LoRA modules and register them as submodules.

        When weights have been loaded, the flags default to what the weights
        contain and, if given explicitly, must agree with it; afterwards the
        weights are loaded non-strictly into this network. Without loaded
        weights both flags are required. ``text_encoder`` and ``unet`` are not
        used by this method.

        Raises:
            AssertionError: If a flag contradicts the loaded weights, or if a
                flag is ``None`` while no weights are loaded.
        """
        if self.weights_sd:
            weights_has_text_encoder, weights_has_unet = self._weight_targets()

            if apply_text_encoder is None:
                apply_text_encoder = weights_has_text_encoder
            else:
                assert (
                    apply_text_encoder == weights_has_text_encoder
                ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"

            if apply_unet is None:
                apply_unet = weights_has_unet
            else:
                assert (
                    apply_unet == weights_has_unet
                ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
        else:
            assert apply_text_encoder is not None and apply_unet is not None, "internal error: flag not set"

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

        if self.weights_sd:
            # strict=False: keys that do not match are reported, not raised.
            info = self.load_state_dict(self.weights_sd, False)
            print(f"weights are loaded: {info}")

    def merge_to(self, text_encoder, unet, dtype, device):
        """Merge the loaded LoRA weights into the wrapped modules.

        The part of the network (text encoder / U-Net) for which the loaded
        weights hold no keys is dropped. Each remaining LoRA module receives
        the weights stored under ``"<lora_name>.<param>"`` with the module
        prefix removed. ``text_encoder`` and ``unet`` are not used here.

        Raises:
            AssertionError: If no weights have been loaded.
        """
        assert self.weights_sd is not None, "weights are not loaded"

        apply_text_encoder, apply_unet = self._weight_targets()

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        # Group the weights by module name in one pass. LoRA names contain no
        # "." (it is replaced by "_" when they are built), so the first dot
        # separates the module name from the parameter name.
        weights_by_module = {}
        for key, value in self.weights_sd.items():
            module_name, _, param_name = key.partition(".")
            weights_by_module.setdefault(module_name, {})[param_name] = value

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.merge_to(weights_by_module.get(lora.lora_name, {}), dtype, device)
        print("weights are merged")

    def enable_gradient_checkpointing(self):
        """Do nothing; kept so callers can invoke it unconditionally."""
        pass

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
        """Return optimizer parameter groups for the LoRA modules.

        One group is produced for the text-encoder modules and one for the
        U-Net modules (each only when non-empty); a group carries an ``"lr"``
        entry only when the corresponding learning rate is not ``None``.
        All parameters of the network are set to require gradients.
        """

        def build_group(loras, lr):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            group = {"params": params}
            if lr is not None:
                group["lr"] = lr
            return group

        self.requires_grad_(True)

        all_params = []
        if self.text_encoder_loras:
            all_params.append(build_group(self.text_encoder_loras, text_encoder_lr))
        if self.unet_loras:
            all_params.append(build_group(self.unet_loras, unet_lr))
        return all_params

    def prepare_grad_etc(self, text_encoder, unet):
        """Mark every parameter of this network as requiring gradients."""
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        """Switch this network to training mode."""
        self.train()

    def get_trainable_params(self):
        """Return an iterator over this network's parameters."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Write the network's state dict to ``file``.

        Args:
            file: Destination path; ``.safetensors`` selects the safetensors
                format, anything else uses ``torch.save``.
            dtype: If not ``None``, every tensor is detached, cloned, moved to
                CPU and cast to this dtype before saving.
            metadata: Optional metadata dict for safetensors files. An empty
                dict is treated like ``None``; a non-empty dict is updated in
                place with the ``sshs_model_hash`` / ``sshs_legacy_hash`` keys.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    @staticmethod
    def set_regions(networks, image):
        """Assign the channels of ``image`` as regions to the first three networks.

        Channel ``i`` of ``image`` (indexed as ``image[:, :, i]`` after scaling
        by 1/255) becomes the region of ``networks[i]``; channels whose maximum
        is 0 are skipped.
        """
        # Imported here because no module-level numpy import is visible in the
        # part of the file shown.
        import numpy as np

        image = image.astype(np.float32) / 255.0
        for i, network in enumerate(networks[:3]):
            region = image[:, :, i]
            if region.max() == 0:
                continue
            network.set_region(torch.tensor(region))

    def set_region(self, region):
        """Set ``region`` on every U-Net LoRA module."""
        for lora in self.unet_loras:
            lora.set_region(region)
|
|
|
from io import BytesIO |
|
import safetensors.torch |
|
import hashlib |
|
|
|
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Args:
        tensors: Tensors to serialize with ``safetensors.torch.save``.
        metadata: Metadata dict; only entries whose key starts with ``"ss_"``
            take part in the hashed serialization. The dict itself is not
            modified.

    Returns:
        Tuple ``(model_hash, legacy_hash)`` as produced by
        ``addnet_hash_safetensors`` and ``addnet_hash_legacy``.
    """
    # Hash over the "ss_" metadata only, presumably so that other (user-editable)
    # metadata cannot change the resulting hash - confirm.
    hashed_metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    serialized = safetensors.torch.save(tensors, hashed_metadata)
    buffer = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(buffer)
    legacy_hash = addnet_hash_legacy(buffer)
    return model_hash, legacy_hash
|
|
|
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    The first 8 bytes of ``b`` are read as a little-endian integer ``n``
    (presumably the length of the header that follows); everything from offset
    ``n + 8`` to the end of the stream is hashed with SHA-256.

    Args:
        b: Seekable binary stream.

    Returns:
        Hex digest of the SHA-256 hash.
    """
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")

    # Skip the 8-byte length field plus the n bytes it announces.
    b.seek(header_len + 8)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()
|
|
|
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    Hashes up to 0x10000 bytes read from offset 0x100000 of ``b`` with SHA-256
    and returns the first 8 hex characters of the digest.

    Args:
        b: Seekable binary stream.

    Returns:
        The first 8 characters of the hex digest.
    """
    m = hashlib.sha256()

    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]
|
|