Spaces:
Sleeping
Sleeping
import io | |
import safetensors | |
import streamlit.file_util | |
from safetensors.torch import serialize | |
from streamlit.runtime.uploaded_file_manager import UploadedFile | |
from tools import lora_tools, torch_tools | |
# Space configuration reference:
# https://huggingface.co/docs/hub/spaces-config-reference

# Page heading and the precision selector; the chosen dtype string is passed
# to every conversion helper further down the script.
streamlit.title("Lora and Embedding Tools")
output_dtype = streamlit.radio(
    label="Save Precision",
    options=["float16", "float32", "bfloat16"],
    index=0,
)

# The container's return value is discarded; the two columns host the
# rescale tool (col1) and the CLIP-removal tool (col2).
streamlit.container()
col1, col2 = streamlit.columns(spec=2, gap="medium")
# Helper that wipes the download buttons once a download has been invoked.
def completed_download_callback():
    """Empty each of the three download-button placeholders.

    Looks up the module-level placeholders for the rescale, strip-CLIP and
    ckpt-conversion tools and calls ``.empty()`` on each, in that order.

    NOTE(review): none of the ``download_button`` calls visible in this file
    pass this function as a callback - confirm whether it is still wired up.
    """
    placeholders = (
        ui_filedownload_rescale,
        ui_filedownload_stripclip,
        ui_filedownload_ckpt,
    )
    for placeholder in placeholders:
        placeholder.empty()
# Left column: rescale the strength of Lora weights.
with col1:
    streamlit.html("<h3>Rescale Lora Strength</h3>")
    ui_fileupload_rescale = streamlit.file_uploader(
        "Upload a safetensors lora",
        type=[".safetensors"],
        key="fileupload_rescale",
    )  # type: UploadedFile
    new_scale_factor = streamlit.number_input(
        "Scale Factor",
        min_value=0.01,
        max_value=100.0,
        value=1.0,
        step=0.01,
    )
    # Reserve the download button's slot now so it lands in this column;
    # the button itself is filled in later, once there is data to offer.
    ui_filedownload_rescale = streamlit.empty()

# Right column: remove CLIP parameters from a Lora file.
with col2:
    streamlit.html("<h3>Remove CLIP Parameters</h3>")
    ui_fileupload_stripclip = streamlit.file_uploader(
        "Upload a safetensors lora",
        type=[".safetensors"],
        key="fileupload_stripclip",
    )  # type: UploadedFile
    # Reserved slot for this tool's download button.
    ui_filedownload_stripclip = streamlit.empty()

streamlit.html("<hr>")

# Below the columns: convert a .ckpt file to a .safetensors file.
streamlit.html("<h3>Convert CKPT to Safetensors (700MB max)</h3>")
ui_fileupload_ckpt = streamlit.file_uploader(
    "Upload a .ckpt file",
    type=[".ckpt"],
    key="fileupload_convertckpt",
)  # type: UploadedFile
# Reserved slot for this tool's download button.
ui_filedownload_ckpt = streamlit.empty()
# ! Rescale Lora
# Runs only when a file is present in the rescale uploader: rescales the
# uploaded Lora with the chosen factor and precision, then offers the result
# through the download placeholder reserved in the left column.
if ui_fileupload_rescale and ui_fileupload_rescale.name is not None:
    # The metadata read from the upload is handed to save() together with
    # the rescaled weights.
    lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_rescale)
    new_weights = lora_tools.rescale_lora_alpha(ui_fileupload_rescale, output_dtype, new_scale_factor)
    new_lora_data = safetensors.torch.save(new_weights, lora_metadata)

    # A BytesIO built from the bytes starts at position 0, ready to be read.
    lora_file_buffer = io.BytesIO(new_lora_data)

    # Output name: upload name minus its last extension, plus the factor.
    file_name = ui_fileupload_rescale.name.rsplit(".", 1)[0]
    output_name = f"{file_name}_rescaled_{new_scale_factor:.2f}.safetensors"

    # Release the upload. The name must not be referenced after `del`:
    # touching it again (e.g. assigning to `.name`) raises NameError and
    # aborts the run before the download button below is rendered.
    ui_fileupload_rescale.close()
    del ui_fileupload_rescale

    ui_filedownload_rescale.download_button("Download Rescaled Weights", lora_file_buffer, output_name, type="primary")
# ! Remove CLIP Parameters
# Runs only when a file is present in the strip-CLIP uploader.
if ui_fileupload_stripclip and ui_fileupload_stripclip.name is not None:
    # Metadata from the upload is passed to save() alongside the stripped
    # weights.
    lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_stripclip)
    stripped_weights = lora_tools.remove_clip_weights(ui_fileupload_stripclip, output_dtype)
    stripped_lora_data = safetensors.torch.save(stripped_weights, lora_metadata)

    # In-memory file positioned at the start, for the download button.
    lora_file_buffer = io.BytesIO(stripped_lora_data)

    # Output name: upload name minus its last extension, plus "_noclip".
    file_name = ui_fileupload_stripclip.name.rsplit(".", 1)[0]
    output_name = f"{file_name}_noclip.safetensors"

    # Done with the upload: close it and drop the module-level reference.
    ui_fileupload_stripclip.close()
    del ui_fileupload_stripclip

    ui_filedownload_stripclip.download_button(
        "Download Stripped Weights",
        lora_file_buffer,
        output_name,
        type="primary",
    )
# ! Convert Checkpoint to Safetensors
# Runs only when a file is present in the ckpt uploader.
# NOTE(review): .ckpt files are commonly pickle-based; confirm that
# torch_tools.convert_ckpt_to_safetensors loads untrusted uploads safely.
if ui_fileupload_ckpt and ui_fileupload_ckpt.name is not None:
    converted_weights = torch_tools.convert_ckpt_to_safetensors(ui_fileupload_ckpt, output_dtype)
    # Unlike the Lora tools above, no metadata is passed to save() here.
    converted_lora_data = safetensors.torch.save(converted_weights)

    # In-memory file positioned at the start, for the download button.
    lora_file_buffer = io.BytesIO(converted_lora_data)

    # Output name: upload name with its last extension swapped for
    # ".safetensors".
    file_name = ui_fileupload_ckpt.name.rsplit(".", 1)[0]
    output_name = f"{file_name}.safetensors"

    # Done with the upload: close it and drop the module-level reference.
    ui_fileupload_ckpt.close()
    del ui_fileupload_ckpt

    ui_filedownload_ckpt.download_button(
        "Download Converted Weights",
        lora_file_buffer,
        output_name,
        type="primary",
    )