import io
import safetensors
import streamlit.file_util
from safetensors.torch import serialize
from streamlit.runtime.uploaded_file_manager import UploadedFile
from tools import lora_tools, torch_tools
# Hugging Face Spaces configuration reference: https://huggingface.co/docs/hub/spaces-config-reference
# Page title and the output-precision selector read by the processing sections below.
streamlit.title("Lora and Embedding Tools")
precision_choices = ["float16", "float32", "bfloat16"]
output_dtype = streamlit.radio("Save Precision", precision_choices, index=0)
# Emit a container (its handle is not kept), then split the page into two columns.
streamlit.container()
page_columns = streamlit.columns(2, gap="medium")
col1, col2 = page_columns
# Helper that wipes the download buttons once a download has been invoked.
def completed_download_callback():
    """Empty each of the three placeholders that hold a download button."""
    for placeholder in (
        ui_filedownload_rescale,
        ui_filedownload_stripclip,
        ui_filedownload_ckpt,
    ):
        placeholder.empty()
with col1:
    # A tool for rescaling the strength of Lora weights.
    # NOTE(review): this literal was split across physical lines (a syntax error) and
    # looks like it lost its HTML tags; the visible text is kept verbatim with escaped
    # newlines - confirm the intended markup.
    streamlit.html("\nRescale Lora Strength\n")
    ui_fileupload_rescale = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_rescale", type=[".safetensors"])  # type: UploadedFile
    new_scale_factor = streamlit.number_input("Scale Factor", value=1.0, step=0.01, max_value=100.0, min_value=0.01)
    # Preallocate the download slot here so it lands in this column; the actual
    # button is added to the placeholder later, once a file has been processed.
    ui_filedownload_rescale = streamlit.empty()
with col2:
    # A tool for removing CLIP parameters from a Lora file.
    # NOTE(review): this literal was split across physical lines (a syntax error) and
    # looks like it lost its HTML tags; the visible text is kept verbatim with an
    # escaped newline - confirm the intended markup.
    streamlit.html("Remove CLIP Parameters\n")
    ui_fileupload_stripclip = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_stripclip", type=[".safetensors"])  # type: UploadedFile
    # Preallocate the download slot so the button appears in this column later.
    ui_filedownload_stripclip = streamlit.empty()
# NOTE(review): both literals below were split across physical lines (a syntax error)
# and look like they lost their HTML tags (the first is presumably a separator);
# the visible text is kept verbatim with escaped newlines - confirm the intended markup.
streamlit.html("\n")
# A tool for converting a .ckpt file to a .safetensors file.
streamlit.html("Convert CKPT to Safetensors (700MB max)\n")
ui_fileupload_ckpt = streamlit.file_uploader("Upload a .ckpt file", key="fileupload_convertckpt", type=[".ckpt"])  # type: UploadedFile
# Preallocate the download slot; the button is added once a file has been converted.
ui_filedownload_ckpt = streamlit.empty()
# ! Rescale Lora
# Runs whenever the rescale uploader currently holds a file.
if ui_fileupload_rescale and ui_fileupload_rescale.name is not None:
    # The metadata read from the upload is passed through to the re-serialized file.
    lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_rescale)
    new_weights = lora_tools.rescale_lora_alpha(ui_fileupload_rescale, output_dtype, new_scale_factor)
    new_lora_data = safetensors.torch.save(new_weights, lora_metadata)
    # Wrap the serialized bytes in an in-memory stream positioned at the start.
    lora_file_buffer = io.BytesIO(new_lora_data)
    # Output name: original name without its last extension, plus the scale used.
    file_name = ui_fileupload_rescale.name.rsplit(".", 1)[0]
    output_name = f"{file_name}_rescaled_{new_scale_factor:.2f}.safetensors"
    ui_fileupload_rescale.close()
    # Drop the reference to the closed upload. The name must not be touched after
    # this point: the previous code assigned `.name = None` right after the `del`,
    # which raised NameError and prevented the download button from being shown.
    del ui_fileupload_rescale
    ui_filedownload_rescale.download_button("Download Rescaled Weights", lora_file_buffer, output_name, type="primary")
# ! Remove CLIP Parameters
# Runs whenever the strip-CLIP uploader currently holds a file.
if ui_fileupload_stripclip and ui_fileupload_stripclip.name is not None:
    # Derive the download name from the upload's name minus its last extension.
    base_name, *_ = ui_fileupload_stripclip.name.rsplit(".", 1)
    download_name = f"{base_name}_noclip.safetensors"

    # Metadata from the upload is passed through to the re-serialized file.
    source_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_stripclip)
    weights_without_clip = lora_tools.remove_clip_weights(ui_fileupload_stripclip, output_dtype)
    serialized = safetensors.torch.save(weights_without_clip, metadata=source_metadata)

    # In-memory stream holding the serialized bytes, positioned at the start.
    download_buffer = io.BytesIO(serialized)

    # Release the upload before offering the result.
    ui_fileupload_stripclip.close()
    del ui_fileupload_stripclip

    ui_filedownload_stripclip.download_button(
        label="Download Stripped Weights",
        data=download_buffer,
        file_name=download_name,
        type="primary",
    )
# ! Convert Checkpoint to Safetensors
# Runs whenever the .ckpt uploader currently holds a file.
if ui_fileupload_ckpt and ui_fileupload_ckpt.name is not None:
    # Same base name as the upload, with the last extension swapped for .safetensors.
    base_name, *_ = ui_fileupload_ckpt.name.rsplit(".", 1)
    download_name = f"{base_name}.safetensors"

    # Convert, then serialize without any metadata argument.
    converted = torch_tools.convert_ckpt_to_safetensors(ui_fileupload_ckpt, output_dtype)
    serialized = safetensors.torch.save(converted)

    # In-memory stream holding the serialized bytes, positioned at the start.
    download_buffer = io.BytesIO(serialized)

    # Release the upload before offering the result.
    ui_fileupload_ckpt.close()
    del ui_fileupload_ckpt

    ui_filedownload_ckpt.download_button(
        label="Download Converted Weights",
        data=download_buffer,
        file_name=download_name,
        type="primary",
    )