|
|
|
import os
import shlex
import subprocess

import gradio as gr
from easygui import msgbox

from .common_gui import get_saveasfilename_path, get_file_path
from library.custom_logging import setup_logging
|
|
|
|
|
log = setup_logging()

# Emoji glyphs used as compact labels for small GUI buttons.
folder_symbol = '\U0001F4C2'  # open file folder
refresh_symbol = '\U0001F504'  # circular arrows
save_style_symbol = '\U0001F4BE'  # floppy disk
document_symbol = '\U0001F4C4'  # page

# Interpreter used to launch the merge scripts: `python3` from PATH on POSIX,
# otherwise the venv interpreter relative to the current working directory.
if os.name == 'posix':
    PYTHON = 'python3'
else:
    PYTHON = './venv/Scripts/python.exe'
|
|
|
|
|
def check_model(model):
    """Validate an optional model path.

    A falsy value (empty string / None) is accepted because every model
    input is optional at this level. A non-empty value must name an
    existing regular file; otherwise a message box is shown to the user
    and False is returned.

    Returns:
        True if the path is empty or points at an existing file, else False.
    """
    if model and not os.path.isfile(model):
        msgbox(f'The provided {model} is not a file')
        return False
    return True
|
|
|
|
|
def verify_conditions(sd_model, lora_models):
    """Return True when enough models were supplied to perform a merge.

    With a base SD model, a single LoRA is enough. Without one, at least
    two LoRA models are needed so they can be merged with each other.
    Falsy entries in ``lora_models`` (empty strings / None) are not counted.
    """
    provided = len([path for path in lora_models if path])
    required = 1 if sd_model else 2
    return provided >= required
|
|
|
|
|
def _build_merge_lora_cmd(
    python,
    sd_model,
    sdxl_model,
    lora_models,
    ratios,
    save_to,
    precision,
    save_precision,
):
    """Build the argv list for networks/merge_lora.py or sdxl_merge_lora.py.

    Only LoRA slots with a non-empty path are passed on. Each path is paired
    with the ratio from the same slot, so ``--models`` and ``--ratios`` stay
    aligned. The result is meant for ``subprocess.run`` without a shell, so
    no manual quoting is applied.
    """
    script = 'sdxl_merge_lora.py' if sdxl_model else 'merge_lora.py'
    cmd = [python, os.path.join('networks', script)]
    if sd_model:
        cmd += ['--sd_model', sd_model]
    cmd += [
        '--save_precision', str(save_precision),
        '--precision', str(precision),
        '--save_to', save_to,
    ]

    # Keep every ratio attached to its own LoRA slot while skipping empty slots.
    selected = [(model, ratio) for model, ratio in zip(lora_models, ratios) if model]
    if selected:
        cmd.append('--models')
        cmd.extend(model for model, _ in selected)
        cmd.append('--ratios')
        cmd.extend(str(ratio) for _, ratio in selected)
    return cmd


def merge_lora(
    sd_model,
    sdxl_model,
    lora_a_model,
    lora_b_model,
    lora_c_model,
    lora_d_model,
    ratio_a,
    ratio_b,
    ratio_c,
    ratio_d,
    save_to,
    precision,
    save_precision,
):
    """Merge up to four LoRA files together, or into an SD checkpoint.

    Validates the inputs, then runs ``networks/merge_lora.py`` (or
    ``networks/sdxl_merge_lora.py`` when ``sdxl_model`` is truthy) with the
    ``PYTHON`` interpreter and waits for it to finish.

    Args:
        sd_model: Optional path of a base checkpoint to merge the LoRAs into.
        sdxl_model: Truthy to use the SDXL merge script.
        lora_a_model, lora_b_model, lora_c_model, lora_d_model: LoRA file
            paths; empty slots are skipped.
        ratio_a, ratio_b, ratio_c, ratio_d: Merge ratio of the LoRA in the
            same slot.
        save_to: Output file path; required.
        precision: Value passed as ``--precision``.
        save_precision: Value passed as ``--save_precision``.

    Returns:
        None. Invalid input or a failed merge is reported (log / message box)
        and aborts without raising.
    """
    log.info('Merge model...')

    lora_models = [lora_a_model, lora_b_model, lora_c_model, lora_d_model]
    ratios = [ratio_a, ratio_b, ratio_c, ratio_d]

    if not verify_conditions(sd_model, lora_models):
        log.warning(
            'Warning: Either provide at least one LoRa model along with the sd_model or at least two LoRa models if no sd_model is provided.'
        )
        return

    for model in [sd_model, *lora_models]:
        if not check_model(model):
            return

    # Without an output path the merge script has nowhere to write its result.
    if not save_to:
        log.warning('Warning: Please provide a path to save the merged model to.')
        return

    run_cmd = _build_merge_lora_cmd(
        PYTHON,
        sd_model,
        sdxl_model,
        lora_models,
        ratios,
        save_to,
        precision,
        save_precision,
    )
    log.info(shlex.join(run_cmd))

    # An argv list with no shell hands paths containing spaces, quotes or
    # shell metacharacters to the script verbatim. This replaces the
    # previous os.system() string on POSIX.
    try:
        result = subprocess.run(run_cmd)
    except OSError as err:
        log.error(f'Could not start the merge script: {err}')
        return

    if result.returncode != 0:
        log.error(f'Merging failed (exit code {result.returncode}).')
        return

    log.info('Done merging...')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_merge_lora_tab(headless=False):
    """Build the "Merge LoRA" tab and wire its button to ``merge_lora``.

    The tab contains path fields for an optional SD checkpoint, four LoRA
    models and the output file. Each path field has a folder button; these
    buttons are hidden when ``headless`` is true. The tab also contains one
    merge-ratio slider per LoRA, an SDXL checkbox, and merge/save precision
    dropdowns.
    """

    def path_field(textbox_kwargs, picker, ext, ext_name):
        # Create a path textbox followed by a folder button whose click calls
        # `picker` with (current path, ext, ext_name) and writes the result
        # back into the textbox. Creation order matches the layout.
        field = gr.Textbox(interactive=True, **textbox_kwargs)
        browse = gr.Button(
            folder_symbol,
            elem_id='open_folder_small',
            visible=not headless,
        )
        browse.click(
            picker,
            inputs=[field, ext, ext_name],
            outputs=field,
            show_progress=False,
        )
        return field

    def ratio_slider(letter):
        # Slider in [0, 1] for the LoRA in slot `letter`.
        return gr.Slider(
            label=f'Model {letter} merge ratio (eg: 0.5 mean 50%)',
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.0,
            interactive=True,
        )

    with gr.Tab('Merge LoRA'):
        gr.Markdown(
            'This utility can merge up to 4 LoRA together or alternatively merge up to 4 LoRA into a SD checkpoint.'
        )

        # Hidden textboxes fed to the pickers alongside the current path;
        # presumably the file-type filter and its display name - confirm
        # against common_gui.get_file_path.
        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
        ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
        ckpt_ext_name = gr.Textbox(value='SD model types', visible=False)

        def lora_field(letter):
            # Path field for the LoRA in slot `letter`.
            return path_field(
                {
                    'label': f'LoRA model "{letter}"',
                    'placeholder': f'Path to the LoRA {letter} model',
                },
                get_file_path,
                lora_ext,
                lora_ext_name,
            )

        with gr.Row():
            sd_model = path_field(
                {
                    'label': 'SD Model',
                    'placeholder': '(Optional) Stable Diffusion model',
                    'info': 'Provide a SD file path IF you want to merge it with LoRA files',
                },
                get_file_path,
                ckpt_ext,
                ckpt_ext_name,
            )
            sdxl_model = gr.Checkbox(label='SDXL model', value=False)

        with gr.Row():
            lora_a_model = lora_field('A')
            lora_b_model = lora_field('B')

        with gr.Row():
            ratio_a = ratio_slider('A')
            ratio_b = ratio_slider('B')

        with gr.Row():
            lora_c_model = lora_field('C')
            lora_d_model = lora_field('D')

        with gr.Row():
            ratio_c = ratio_slider('C')
            ratio_d = ratio_slider('D')

        with gr.Row():
            save_to = path_field(
                {
                    'label': 'Save to',
                    'placeholder': 'path for the file to save...',
                },
                get_saveasfilename_path,
                lora_ext,
                lora_ext_name,
            )
            precision = gr.Dropdown(
                label='Merge precision',
                choices=['fp16', 'bf16', 'float'],
                value='float',
                interactive=True,
            )
            save_precision = gr.Dropdown(
                label='Save precision',
                choices=['fp16', 'bf16', 'float'],
                value='fp16',
                interactive=True,
            )

        merge_button = gr.Button('Merge model')

        # Input order must match merge_lora's positional parameters.
        merge_inputs = [
            sd_model,
            sdxl_model,
            lora_a_model,
            lora_b_model,
            lora_c_model,
            lora_d_model,
            ratio_a,
            ratio_b,
            ratio_c,
            ratio_d,
            save_to,
            precision,
            save_precision,
        ]
        merge_button.click(
            merge_lora,
            inputs=merge_inputs,
            show_progress=False,
        )
|
|