File size: 4,177 Bytes
aa5d6d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import io

import safetensors
import streamlit.file_util
from safetensors.torch import serialize
from streamlit.runtime.uploaded_file_manager import UploadedFile

from tools import lora_tools, torch_tools

# https://huggingface.co/docs/hub/spaces-config-reference

streamlit.title("Lora and Embedding Tools")

# Precision name passed to every tool below as the dtype its output is saved in.
output_dtype = streamlit.radio(
	"Save Precision",
	options=["float16", "float32", "bfloat16"],
	index=0,
)
# NOTE(review): the returned container is discarded and nothing is placed in it.
streamlit.container()
# Two side-by-side columns that hold the individual tools.
col1, col2 = streamlit.columns(spec=2, gap="medium")

def completed_download_callback():
	"""Clear every tool's download-button placeholder.

	Calls ``.empty()`` on the rescale, strip-CLIP and CKPT-conversion download
	placeholders, in that order, so any download button they currently show is
	removed. The placeholders are module-level globals created later in the
	script, so this must only run once they exist.

	NOTE(review): not passed as a callback (e.g. ``on_click``) to any widget in
	the code visible here.
	"""
	for placeholder in (ui_filedownload_rescale, ui_filedownload_stripclip, ui_filedownload_ckpt):
		placeholder.empty()

# Page layout: each tool's widgets are created inside its column. Each tool also
# reserves a streamlit.empty() placeholder in that column, so the download button
# added later by the processing code at the bottom of the script renders there too.
with col1:
	# A tool for rescaling the strength of Lora weights
	streamlit.html("<h3>Rescale Lora Strength</h3>")
	# file_uploader yields None until the user uploads a file.
	ui_fileupload_rescale = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_rescale", type=[".safetensors"])  # type: UploadedFile | None
	# Factor passed to the rescale helper; the widget limits it to [0.01, 100.0].
	new_scale_factor = streamlit.number_input("Scale Factor", value=1.0, step=0.01, max_value=100.0, min_value=0.01)
	
	# Reserve a placeholder here so the download button, added once the file has
	# been processed, renders in this column rather than at the bottom of the page.
	ui_filedownload_rescale = streamlit.empty()

with col2:
	# A tool for removing CLIP parameters from a Lora file
	streamlit.html("<h3>Remove CLIP Parameters</h3>")
	# file_uploader yields None until the user uploads a file.
	ui_fileupload_stripclip = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_stripclip", type=[".safetensors"])  # type: UploadedFile | None
	
	# Preallocate the download button slot so it stays in this column.
	ui_filedownload_stripclip = streamlit.empty()
	
	# Horizontal rule separating the two tools stacked in this column.
	streamlit.html("<hr>")
	
	# A tool for converting a .ckpt file to a .safetensors file
	# NOTE(review): the 700MB figure in this heading is not enforced by this script;
	# presumably it mirrors the server's upload-size setting — confirm.
	streamlit.html("<h3>Convert CKPT to Safetensors (700MB max)</h3>")
	# file_uploader yields None until the user uploads a file.
	ui_fileupload_ckpt = streamlit.file_uploader("Upload a .ckpt file", key="fileupload_convertckpt", type=[".ckpt"])  # type: UploadedFile | None
	
	# Preallocate the download button slot so it stays in this column.
	ui_filedownload_ckpt = streamlit.empty()

# ! Rescale Lora
# When a LoRA has been uploaded, rebuild it with its strength rescaled by
# `new_scale_factor` (output_dtype is passed to the helper as well), and offer
# the result as "<original stem>_rescaled.safetensors".
if ui_fileupload_rescale and ui_fileupload_rescale.name is not None:
	# Metadata read from the upload is passed through unchanged to the output file.
	lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_rescale)
	new_weights = lora_tools.rescale_lora_alpha(ui_fileupload_rescale, output_dtype, new_scale_factor)

	# Serialize in memory; a BytesIO created from the bytes is already at position 0.
	new_lora_data = safetensors.torch.save(new_weights, lora_metadata)
	lora_file_buffer = io.BytesIO(new_lora_data)

	# Strip only the final extension, so inner dots in the file name survive.
	file_name = ui_fileupload_rescale.name.rsplit(".", 1)[0]
	output_name = f"{file_name}_rescaled.safetensors"

	# Release the uploaded buffer. The name is unbound after `del`, so nothing
	# below may touch it (assigning `.name` afterwards raised NameError).
	ui_fileupload_rescale.close()
	del ui_fileupload_rescale

	ui_filedownload_rescale.download_button("Download Rescaled Weights", lora_file_buffer, output_name, type="primary")

# ! Remove CLIP Parameters
# When a LoRA has been uploaded, offer "<original stem>_noclip.safetensors": the
# weights returned by lora_tools.remove_clip_weights (given output_dtype), saved
# with the metadata read from the uploaded file.
if ui_fileupload_stripclip and ui_fileupload_stripclip.name is not None:
	stripclip_stem = ui_fileupload_stripclip.name.rsplit(".", 1)[0]
	stripclip_output_name = f"{stripclip_stem}_noclip.safetensors"

	source_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_stripclip)
	weights_without_clip = lora_tools.remove_clip_weights(ui_fileupload_stripclip, output_dtype)
	# Built straight from the serialized bytes, so the buffer starts at position 0.
	stripclip_buffer = io.BytesIO(safetensors.torch.save(weights_without_clip, source_metadata))

	# Done with the upload: release its buffer and drop the reference.
	ui_fileupload_stripclip.close()
	del ui_fileupload_stripclip

	ui_filedownload_stripclip.download_button(
		"Download Stripped Weights",
		stripclip_buffer,
		stripclip_output_name,
		type="primary",
	)

# ! Convert Checkpoint to Safetensors
# When a .ckpt has been uploaded, convert it via torch_tools (given output_dtype)
# and offer the result as "<original stem>.safetensors".
# NOTE(review): .ckpt files are typically pickle-based and this one is user-uploaded;
# confirm torch_tools.convert_ckpt_to_safetensors loads it safely (e.g. weights_only=True).
if ui_fileupload_ckpt and ui_fileupload_ckpt.name is not None:
	ckpt_stem = ui_fileupload_ckpt.name.rsplit(".", 1)[0]
	ckpt_output_name = f"{ckpt_stem}.safetensors"

	converted_tensors = torch_tools.convert_ckpt_to_safetensors(ui_fileupload_ckpt, output_dtype)
	# No metadata argument is given, so the output carries no custom header metadata.
	ckpt_buffer = io.BytesIO(safetensors.torch.save(converted_tensors))

	# Done with the upload: release its buffer and drop the reference.
	ui_fileupload_ckpt.close()
	del ui_fileupload_ckpt

	ui_filedownload_ckpt.download_button(
		"Download Converted Weights",
		ckpt_buffer,
		ckpt_output_name,
		type="primary",
	)