Commit aa5d6d0 • kjerk committed
Parent(s): 9c24431

Add initial tools, layout, and config.
Files changed:
- .gitignore +144 -0
- .streamlit/config.toml +3 -0
- README.md +27 -1
- app.py +103 -0
- pycharm_runner.py +9 -0
- requirements.txt +3 -0
- tools/__init__.py +2 -0
- tools/lora_tools.py +66 -0
- tools/torch_tools.py +33 -0
.gitignore
ADDED
@@ -0,0 +1,144 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# IDEs
+/.idea/
+/.vscode/
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+/wandb/
+wandb/
.streamlit/config.toml
ADDED
@@ -0,0 +1,3 @@
+[server]
+
+maxUploadSize = 700
README.md
CHANGED
@@ -10,4 +10,30 @@ pinned: false
 license: apache-2.0
 ---
 
-
+# Lora and Embedding Tools
+
+😻 **Lora and Embedding Tools** is a quick toolbelt to help you manipulate Lora and Text Embedding files. This tool provides several functionalities, including rescaling Lora weights, removing CLIP parameters, converting checkpoint files to safetensors format, and whatever else I decide to add in the future.
+
+## Features
+
+- **Rescale Lora Strength**: Adjust the strength of Lora weights by specifying a new scale factor. Rescales the embedded Alpha scale. (No more 0.6, etc.)
+- **Remove CLIP Parameters**: Strip out CLIP parameters from a Lora file. If you have an overbaked or overaggressive Lora file, this can sometimes rescue it, or make it more agnostic for other models.
+- **Convert CKPT to Safetensors**: Convert `.ckpt` files to `.safetensors` format to get that pickle smell out of your weights.
+
+## How to Use
+
+### Rescale Lora Strength
+
+1. Specify the new scale factor first.
+2. Upload a `.safetensors` Lora file; conversion begins immediately.
+3. Download the rescaled weights.
+
+### Remove CLIP Parameters
+
+1. Upload a `.safetensors` Lora file.
+2. Download the file with CLIP parameters removed.
+
+### Convert CKPT to Safetensors
+
+1. Upload a `.ckpt` file (maximum size 700MB).
+2. Download the converted `.safetensors` file.
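For context on the rescale feature: in the common Lora convention the update applied to a base weight is scaled by alpha/rank, so multiplying each stored `.alpha` tensor by a factor s bakes an effective strength of s into the file. A minimal sketch of that relationship (illustrative names and shapes, not code from this commit):

import torch

rank = 8
alpha = torch.tensor(8.0)      # stored next to each lora_down/lora_up pair
down = torch.randn(rank, 64)   # lora_down ("A") matrix
up = torch.randn(64, rank)     # lora_up ("B") matrix

delta = (alpha / rank) * (up @ down)                  # update at strength 1.0
delta_scaled = ((0.6 * alpha) / rank) * (up @ down)   # alpha scaled by 0.6

assert torch.allclose(delta_scaled, 0.6 * delta)      # strength 0.6 baked in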
app.py
ADDED
@@ -0,0 +1,103 @@
+import io
+
+import safetensors.torch  # importing the submodule makes safetensors.torch.save() available
+import streamlit
+from streamlit.runtime.uploaded_file_manager import UploadedFile
+
+from tools import lora_tools, torch_tools
+
+# https://huggingface.co/docs/hub/spaces-config-reference
+
+streamlit.title("Lora and Embedding Tools")
+
+output_dtype = streamlit.radio("Save Precision", ["float16", "float32", "bfloat16"], index=0)
+streamlit.container()
+col1, col2 = streamlit.columns(2, gap="medium")
+
+# A helper method to wipe the download buttons once invoked
+def completed_download_callback():
+    ui_filedownload_rescale.empty()
+    ui_filedownload_stripclip.empty()
+    ui_filedownload_ckpt.empty()
+
+with col1:
+    # A tool for rescaling the strength of Lora weights
+    streamlit.html("<h3>Rescale Lora Strength</h3>")
+    ui_fileupload_rescale = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_rescale", type=[".safetensors"])  # type: UploadedFile
+    new_scale_factor = streamlit.number_input("Scale Factor", value=1.0, step=0.01, max_value=100.0, min_value=0.01)
+
+    # Preallocate the download button here so it lands in the correct column; the button itself is added later.
+    ui_filedownload_rescale = streamlit.empty()
+
+with col2:
+    # A tool for removing CLIP parameters from a Lora file
+    streamlit.html("<h3>Remove CLIP Parameters</h3>")
+    ui_fileupload_stripclip = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_stripclip", type=[".safetensors"])  # type: UploadedFile
+
+    # Preallocate download button
+    ui_filedownload_stripclip = streamlit.empty()
+
+streamlit.html("<hr>")
+
+# A tool for converting a .ckpt file to a .safetensors file
+streamlit.html("<h3>Convert CKPT to Safetensors (700MB max)</h3>")
+ui_fileupload_ckpt = streamlit.file_uploader("Upload a .ckpt file", key="fileupload_convertckpt", type=[".ckpt"])  # type: UploadedFile
+
+# Preallocate download button
+ui_filedownload_ckpt = streamlit.empty()
+
+# ! Rescale Lora
+if ui_fileupload_rescale and ui_fileupload_rescale.name is not None:
+    lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_rescale)
+    new_weights = lora_tools.rescale_lora_alpha(ui_fileupload_rescale, output_dtype, new_scale_factor)
+
+    new_lora_data = safetensors.torch.save(new_weights, lora_metadata)
+
+    lora_file_buffer = io.BytesIO()
+    lora_file_buffer.write(new_lora_data)
+    lora_file_buffer.seek(0)
+
+    file_name = ui_fileupload_rescale.name.rsplit(".", 1)[0]
+    output_name = f"{file_name}_rescaled.safetensors"
+
+    ui_fileupload_rescale.close()
+    del ui_fileupload_rescale
+
+    ui_filedownload_rescale.download_button("Download Rescaled Weights", lora_file_buffer, output_name, type="primary")
+
+# ! Remove CLIP Parameters
+if ui_fileupload_stripclip and ui_fileupload_stripclip.name is not None:
+    lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_stripclip)
+    stripped_weights = lora_tools.remove_clip_weights(ui_fileupload_stripclip, output_dtype)
+
+    stripped_lora_data = safetensors.torch.save(stripped_weights, lora_metadata)
+
+    lora_file_buffer = io.BytesIO()
+    lora_file_buffer.write(stripped_lora_data)
+    lora_file_buffer.seek(0)
+
+    file_name = ui_fileupload_stripclip.name.rsplit(".", 1)[0]
+    output_name = f"{file_name}_noclip.safetensors"
+
+    ui_fileupload_stripclip.close()
+    del ui_fileupload_stripclip
+
+    ui_filedownload_stripclip.download_button("Download Stripped Weights", lora_file_buffer, output_name, type="primary")
+
+# ! Convert Checkpoint to Safetensors
+if ui_fileupload_ckpt and ui_fileupload_ckpt.name is not None:
+    converted_weights = torch_tools.convert_ckpt_to_safetensors(ui_fileupload_ckpt, output_dtype)
+
+    converted_lora_data = safetensors.torch.save(converted_weights)
+
+    lora_file_buffer = io.BytesIO()
+    lora_file_buffer.write(converted_lora_data)
+    lora_file_buffer.seek(0)
+
+    file_name = ui_fileupload_ckpt.name.rsplit(".", 1)[0]
+    output_name = f"{file_name}.safetensors"
+
+    ui_fileupload_ckpt.close()
+    del ui_fileupload_ckpt
+
+    ui_filedownload_ckpt.download_button("Download Converted Weights", lora_file_buffer, output_name, type="primary")
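One note on app.py: `completed_download_callback` is defined but never attached to a button. If the intent is to clear the download placeholders after a click, Streamlit's `on_click` hook on `download_button` would be the usual wiring; a sketch under that assumption:

ui_filedownload_rescale.download_button(
    "Download Rescaled Weights",
    lora_file_buffer,
    output_name,
    type="primary",
    on_click=completed_download_callback,  # wipe all three placeholders once the download starts
)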
pycharm_runner.py
ADDED
@@ -0,0 +1,9 @@
+# https://discuss.streamlit.io/t/cannot-debug-streamlit-in-pycharm-2023-3-3/61581/2
+
+try:
+    from streamlit.web import bootstrap
+except ImportError:
+    from streamlit import bootstrap
+
+real_script = 'app.py'
+bootstrap.run(real_script, f'pycharm_runner.py {real_script}', [], {})
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+safetensors
+streamlit
+torch
tools/__init__.py
ADDED
@@ -0,0 +1,2 @@
+if __name__ == '__main__':
+    print('__main__ not allowed in modules')
tools/lora_tools.py
ADDED
@@ -0,0 +1,66 @@
+import io
+import json
+
+import safetensors.torch  # the submodule import makes safetensors.torch.load() available
+import torch
+
+from .torch_tools import get_target_dtype_ref
+
+def read_safetensors_metadata(lora_upload: io.BytesIO) -> dict:
+    # This is a simple file structure: the first 8 bytes are the header length (little-endian).
+    # Read (length) bytes starting from [8] to get the header (a json string).
+    lora_upload.seek(0)
+
+    metadata_length = int.from_bytes(lora_upload.read(8), byteorder='little')
+
+    # The read(8) above already left the cursor at offset 8.
+    metadata_raw = lora_upload.read(metadata_length)
+
+    metadata_raw = metadata_raw.decode("utf-8")
+    metadata_raw = metadata_raw.strip()
+    metadata_dict = json.loads(metadata_raw)
+
+    # Rewind the buffer to the start, we were just peeking at the metadata.
+    lora_upload.seek(0)
+
+    return metadata_dict.get('__metadata__', {})
+
+def rescale_lora_alpha(lora_upload: io.BytesIO, output_dtype, target_weight: float = 1.0) -> dict:
+    output_dtype = get_target_dtype_ref(output_dtype)
+
+    loaded_tensors = safetensors.torch.load(lora_upload.getvalue())
+
+    # Upcast everything to float32 so the rescale happens at full precision.
+    initial_tensors = {}
+    for key, tensor in loaded_tensors.items():
+        initial_tensors[key] = tensor.to(dtype=torch.float32)
+
+    # Scale the embedded alpha tensors, then cast everything to the requested dtype.
+    new_tensors = {}
+    for key, val in initial_tensors.items():
+        if key.endswith(".alpha"):
+            val = val * target_weight
+        new_tensors[key] = val.to(dtype=output_dtype)
+
+    return new_tensors
+
+def remove_clip_weights(lora_upload: io.BytesIO, output_dtype) -> dict:
+    output_dtype = get_target_dtype_ref(output_dtype)
+
+    loaded_tensors = safetensors.torch.load(lora_upload.getvalue())
+
+    initial_tensors = {}
+    for key, tensor in loaded_tensors.items():
+        initial_tensors[key] = tensor.to(dtype=torch.float32)
+
+    # Drop the text-encoder ("lora_te*") tensors, keep everything else.
+    filtered_tensors = {}
+    for key, val in initial_tensors.items():
+        if key.startswith("lora_te1") or key.startswith("lora_te2"):
+            continue
+        filtered_tensors[key] = val.to(dtype=output_dtype)
+
+    return filtered_tensors
+
+if __name__ == '__main__':
+    print('__main__ not allowed in modules')
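The helpers above also work outside Streamlit, since they only need a BytesIO-like object. A hypothetical local-file usage sketch (paths are examples, not from this repo):

import io

import safetensors.torch

from tools import lora_tools

with open("example_lora.safetensors", "rb") as f:
    buf = io.BytesIO(f.read())

metadata = lora_tools.read_safetensors_metadata(buf)           # header '__metadata__', if any
rescaled = lora_tools.rescale_lora_alpha(buf, "float16", 0.6)  # bake in a 0.6 strength
with open("example_lora_rescaled.safetensors", "wb") as out:
    out.write(safetensors.torch.save(rescaled, metadata))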
tools/torch_tools.py
ADDED
@@ -0,0 +1,33 @@
+import io
+
+import torch
+
+def get_target_dtype_ref(target_dtype) -> torch.dtype:
+    if isinstance(target_dtype, torch.dtype):
+        return target_dtype
+
+    if target_dtype == "float16":
+        return torch.float16
+    elif target_dtype == "float32":
+        return torch.float32
+    elif target_dtype == "bfloat16":
+        return torch.bfloat16
+    else:
+        raise ValueError(f"Invalid target_dtype: {target_dtype}")
+
+def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
+    target_dtype = get_target_dtype_ref(target_dtype)
+    ckpt_data = ckpt_upload.getvalue()
+
+    # Load the checkpoint; torch.load expects a file-like object, so wrap the raw bytes.
+    checkpoint = torch.load(io.BytesIO(ckpt_data), map_location="cpu")
+
+    # Convert the checkpoint to a dictionary of tensors in the target dtype
+    tensor_dict = {}
+    for key, val in checkpoint.items():
+        tensor_dict[key] = val.to(dtype=target_dtype)
+
+    return tensor_dict
+
+if __name__ == '__main__':
+    print('__main__ not allowed in modules')
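The converter is likewise usable from a script; a hypothetical sketch (path is an example; since `.ckpt` files are pickle-based, only load ones whose source you trust):

import io

import safetensors.torch

from tools import torch_tools

with open("example_model.ckpt", "rb") as f:
    buf = io.BytesIO(f.read())

tensors = torch_tools.convert_ckpt_to_safetensors(buf, "float16")
with open("example_model.safetensors", "wb") as out:
    out.write(safetensors.torch.save(tensors))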