kjerk commited on
Commit
aa5d6d0
1 Parent(s): 9c24431

Add initial tools, layout, and config.

Browse files
.gitignore ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # IDEs
7
+ /.idea/
8
+ /.vscode/
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101
+ __pypackages__/
102
+
103
+ # Celery stuff
104
+ celerybeat-schedule
105
+ celerybeat.pid
106
+
107
+ # SageMath parsed files
108
+ *.sage.py
109
+
110
+ # Environments
111
+ .env
112
+ .venv
113
+ env/
114
+ venv/
115
+ ENV/
116
+ env.bak/
117
+ venv.bak/
118
+
119
+ # Spyder project settings
120
+ .spyderproject
121
+ .spyproject
122
+
123
+ # Rope project settings
124
+ .ropeproject
125
+
126
+ # mkdocs documentation
127
+ /site
128
+
129
+ # mypy
130
+ .mypy_cache/
131
+ .dmypy.json
132
+ dmypy.json
133
+
134
+ # Pyre type checker
135
+ .pyre/
136
+
137
+ # pytype static type analyzer
138
+ .pytype/
139
+
140
+ # Cython debug symbols
141
+ cython_debug/
142
+
143
+ /wandb/
144
+ wandb/
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [server]
2
+
3
+ maxUploadSize = 700
README.md CHANGED
@@ -10,4 +10,30 @@ pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  license: apache-2.0
11
  ---
12
 
13
+ # Lora and Embedding Tools
14
+
15
+ 😻 **Lora and Embedding Tools** is a quick toolbelt to help you manipulate Lora and Text Embedding files. This tool provides several functionalities, including rescaling Lora weights, removing CLIP parameters, converting checkpoint files to safetensors format, and whatever else I decide to add in the future.
16
+
17
+ ## Features
18
+
19
+ - **Rescale Lora Strength**: Adjust the strength of Lora weights by specifying a new scale factor. Rescales the embedded Alpha scale. (No more 0.6, etc)
20
+ - **Remove CLIP Parameters**: Strip out CLIP parameters from a Lora file. If you have an overbaked or overaggressive Lora file, this can rescue it sometimes, or make it more agnostic for other models.
21
+ - **Convert CKPT to Safetensors**: Convert `.ckpt` files to `.safetensors` format to get that pickle smell out of your weights.
22
+
23
+ ## How to Use
24
+
25
+ ### Rescale Lora Strength
26
+
27
+ 1. Specify the new scale factor first.
28
+ 2. Upload a `.safetensors` Lora file; conversion begins immediately.
29
+ 3. Download the rescaled weights.
30
+
31
+ ### Remove CLIP Parameters
32
+
33
+ 1. Upload a `.safetensors` Lora file.
34
+ 2. Download the file with CLIP parameters removed.
35
+
36
+ ### Convert CKPT to Safetensors
37
+
38
+ 1. Upload a `.ckpt` file (maximum size 700MB).
39
+ 2. Download the converted `.safetensors` file.
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Streamlit front-end for the Lora/embedding toolbelt.
# Top-level statements run on every Streamlit rerun; widget state is keyed
# by the `key=` arguments, and statement order defines the page layout.
import io

import safetensors
import streamlit.file_util
from safetensors.torch import serialize
from streamlit.runtime.uploaded_file_manager import UploadedFile

from tools import lora_tools, torch_tools

# https://huggingface.co/docs/hub/spaces-config-reference

streamlit.title("Lora and Embedding Tools")

# Shared precision selector: all three tools below save with this dtype.
output_dtype = streamlit.radio("Save Precision", ["float16", "float32", "bfloat16"], index=0)
streamlit.container()
col1, col2 = streamlit.columns(2, gap="medium")

# A helper method to wipe a download button once invoked
# NOTE(review): not referenced anywhere in this file (no on_click wires it up)
# — dead code as written; confirm whether it was meant to be a button callback.
def completed_download_callback():
    ui_filedownload_rescale.empty()
    ui_filedownload_stripclip.empty()
    ui_filedownload_ckpt.empty()

with col1:
    # A tool for rescaling the strength of Lora weights
    streamlit.html("<h3>Rescale Lora Strength</h3>")
    ui_fileupload_rescale = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_rescale", type=[".safetensors"]) # type: UploadedFile
    new_scale_factor = streamlit.number_input("Scale Factor", value=1.0, step=0.01, max_value=100.0, min_value=0.01)

    # Let's preallocate the download button here so it's in the correct column, we can just add the button later.
    ui_filedownload_rescale = streamlit.empty()

with col2:
    # A tool for removing CLIP parameters from a Lora file
    streamlit.html("<h3>Remove CLIP Parameters</h3>")
    ui_fileupload_stripclip = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_stripclip", type=[".safetensors"]) # type: UploadedFile

    # Preallocate download button
    ui_filedownload_stripclip = streamlit.empty()

streamlit.html("<hr>")

# A tool for converting a .ckpt file to a .safetensors file
# (700MB cap comes from maxUploadSize in .streamlit/config.toml)
streamlit.html("<h3>Convert CKPT to Safetensors (700MB max)</h3>")
ui_fileupload_ckpt = streamlit.file_uploader("Upload a .ckpt file", key="fileupload_convertckpt", type=[".ckpt"]) # type: UploadedFile

# Preallocate download button
ui_filedownload_ckpt = streamlit.empty()
# ! Rescale Lora
# Runs when a file has been uploaded to the rescale widget: rescales the
# embedded alpha tensors and offers the re-serialized file for download.
if ui_fileupload_rescale and ui_fileupload_rescale.name is not None:
    # Peek at the embedded metadata so it can be carried over to the new file.
    lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_rescale)
    new_weights = lora_tools.rescale_lora_alpha(ui_fileupload_rescale, output_dtype, new_scale_factor)

    # Re-serialize to safetensors bytes, preserving the original metadata.
    new_lora_data = safetensors.torch.save(new_weights, lora_metadata)

    lora_file_buffer = io.BytesIO(new_lora_data)

    # Download name: "<original stem>_rescaled.safetensors"
    file_name = ui_fileupload_rescale.name.rsplit(".", 1)[0]
    output_name = f"{file_name}_rescaled.safetensors"

    # Release the upload. BUGFIX: the original assigned
    # `ui_fileupload_rescale.name = None` AFTER `del ui_fileupload_rescale`,
    # which raised NameError on every upload; dropped to match the other handlers.
    ui_fileupload_rescale.close()
    del ui_fileupload_rescale

    ui_filedownload_rescale.download_button("Download Rescaled Weights", lora_file_buffer, output_name, type="primary")
# ! Remove CLIP Parameters
# Runs when a file has been uploaded to the strip-CLIP widget: drops the
# text-encoder tensors and offers the re-serialized file for download.
if ui_fileupload_stripclip and ui_fileupload_stripclip.name is not None:
    # Preserve the embedded metadata, then filter out the CLIP tensors.
    source_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_stripclip)
    kept_weights = lora_tools.remove_clip_weights(ui_fileupload_stripclip, output_dtype)

    # Re-serialize and stage the bytes in an in-memory buffer for download.
    download_buffer = io.BytesIO()
    download_buffer.write(safetensors.torch.save(kept_weights, source_metadata))
    download_buffer.seek(0)

    # Download name: "<original stem>_noclip.safetensors"
    base_stem = ui_fileupload_stripclip.name.rsplit(".", 1)[0]
    download_name = f"{base_stem}_noclip.safetensors"

    # Release the upload before exposing the download.
    ui_fileupload_stripclip.close()
    del ui_fileupload_stripclip

    ui_filedownload_stripclip.download_button("Download Stripped Weights", download_buffer, download_name, type="primary")
# ! Convert Checkpoint to Safetensors
# Runs when a .ckpt file has been uploaded: converts the pickle checkpoint
# to a tensor dict and offers it as a .safetensors download.
if ui_fileupload_ckpt and ui_fileupload_ckpt.name is not None:
    converted_tensors = torch_tools.convert_ckpt_to_safetensors(ui_fileupload_ckpt, output_dtype)

    # Serialize to safetensors bytes (no metadata for a plain conversion)
    # and stage them in an in-memory buffer for download.
    staging_buffer = io.BytesIO()
    staging_buffer.write(safetensors.torch.save(converted_tensors))
    staging_buffer.seek(0)

    # Download name: "<original stem>.safetensors"
    stem = ui_fileupload_ckpt.name.rsplit(".", 1)[0]
    download_name = f"{stem}.safetensors"

    # Release the upload before exposing the download.
    ui_fileupload_ckpt.close()
    del ui_fileupload_ckpt

    ui_filedownload_ckpt.download_button("Download Converted Weights", staging_buffer, download_name, type="primary")
pycharm_runner.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
# Debug-friendly launcher so Streamlit apps can be run under the PyCharm debugger.
# https://discuss.streamlit.io/t/cannot-debug-streamlit-in-pycharm-2023-3-3/61581/2

# Newer Streamlit versions moved bootstrap under streamlit.web; fall back for older ones.
try:
    from streamlit.web import bootstrap
except ImportError:
    from streamlit import bootstrap

TARGET_SCRIPT = 'app.py'
bootstrap.run(TARGET_SCRIPT, f'pycharm_runner.py {TARGET_SCRIPT}', [], {})
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ safetensors
2
+ streamlit
3
+ torch
tools/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
# Package marker for `tools`; this file is import-only, never an entry point.
if __name__ == '__main__':
    print('__main__ not allowed in modules')
tools/lora_tools.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import json
3
+
4
+ import safetensors
5
+ import torch
6
+ from safetensors.torch import serialize
7
+
8
+ from .torch_tools import get_target_dtype_ref
9
+
10
def read_safetensors_metadata(lora_upload: io.BytesIO) -> dict:
    """Peek at a safetensors buffer and return its '__metadata__' dict ({} if absent).

    The safetensors layout begins with an unsigned little-endian 8-byte header
    length, followed by that many bytes of JSON header. The buffer position is
    restored to 0 before returning, so callers can re-read the whole file.
    """
    lora_upload.seek(0)
    header_size = int.from_bytes(lora_upload.read(8), byteorder='little')

    # The JSON header starts right after the 8-byte length field.
    lora_upload.seek(8)
    header_json = lora_upload.read(header_size).decode("utf-8").strip()
    header_dict = json.loads(header_json)

    # Rewind — this function only peeks at the header.
    lora_upload.seek(0)

    return header_dict.get('__metadata__', {})
+
29
def rescale_lora_alpha(lora_upload: io.BytesIO, output_dtype, target_weight: float = 1.0) -> dict:
    """Rescale a Lora's baked-in strength by multiplying every '.alpha' tensor.

    Tensors are upcast to float32 for the arithmetic, then cast to the
    requested output dtype on the way out.

    :param lora_upload: in-memory safetensors file (file-like with getvalue()).
    :param output_dtype: torch.dtype or its string name ("float16", ...).
    :param target_weight: multiplier applied to each embedded alpha scale.
    :return: dict of tensors ready for safetensors.torch.save().
    """
    output_dtype = get_target_dtype_ref(output_dtype)

    loaded_tensors = safetensors.torch.load(lora_upload.getvalue())

    # Single pass: the original built a full float32 copy of every tensor in an
    # intermediate dict first, doubling peak memory for no behavioral difference.
    new_tensors = {}
    for key, tensor in loaded_tensors.items():
        # Work in float32 so the alpha multiply doesn't lose precision in fp16/bf16.
        tensor = tensor.to(dtype=torch.float32)
        if key.endswith(".alpha"):
            tensor = tensor * target_weight
        new_tensors[key] = tensor.to(dtype=output_dtype)

    return new_tensors
+
47
def remove_clip_weights(lora_upload: io.BytesIO, output_dtype) -> dict:
    """Drop the text-encoder (CLIP) tensors from a Lora, keeping everything else.

    :param lora_upload: in-memory safetensors file (file-like with getvalue()).
    :param output_dtype: torch.dtype or its string name for the kept tensors.
    :return: dict of the remaining tensors, cast to `output_dtype`.
    """
    output_dtype = get_target_dtype_ref(output_dtype)

    loaded_tensors = safetensors.torch.load(lora_upload.getvalue())

    # NOTE(review): only the SDXL dual-text-encoder prefixes are stripped here;
    # SD1.x Loras name their text encoder "lora_te_..." and would pass through
    # untouched — confirm that is intended.
    filtered_tensors = {}
    for key, tensor in loaded_tensors.items():
        if key.startswith(("lora_te1", "lora_te2")):
            continue
        # Single pass: the original upcast EVERY tensor (including the soon-to-be
        # discarded CLIP ones) into an intermediate float32 dict first, doubling
        # peak memory. Routing through float32 matches rescale_lora_alpha.
        filtered_tensors[key] = tensor.to(dtype=torch.float32).to(dtype=output_dtype)

    return filtered_tensors
+
65
# This file is a library module; complain if someone executes it directly.
if __name__ == '__main__':
    print('__main__ not allowed in modules')
tools/torch_tools.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import torch
4
+
5
def get_target_dtype_ref(target_dtype: str) -> torch.dtype:
    """Resolve a dtype name ("float16" / "float32" / "bfloat16") to a torch.dtype.

    Values that are already a torch.dtype pass straight through.

    :raises ValueError: for any unrecognized dtype name.
    """
    # Already resolved — nothing to do.
    if isinstance(target_dtype, torch.dtype):
        return target_dtype

    name_to_dtype = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    resolved = name_to_dtype.get(target_dtype)
    if resolved is None:
        raise ValueError(f"Invalid target_dtype: {target_dtype}")
    return resolved
+
18
def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
    """Load a pickle-based .ckpt upload and return its weights as a flat tensor dict.

    BUGFIX: the original passed raw bytes (getvalue()) to torch.load, which
    requires a file-like object or path — every conversion crashed. We now pass
    the upload buffer itself, rewound to the start.

    :param ckpt_upload: in-memory .ckpt file (file-like).
    :param target_dtype: torch.dtype or its string name for floating-point tensors.
    :return: dict[str, torch.Tensor] suitable for safetensors.torch.save().
    """
    target_dtype = get_target_dtype_ref(target_dtype)

    # SECURITY: torch.load unpickles the upload, which can execute arbitrary
    # code from a malicious file; weights_only=True would be the safer choice
    # if legacy checkpoints allow it — flagged for review.
    ckpt_upload.seek(0)
    checkpoint = torch.load(ckpt_upload, map_location="cpu")

    # Stable-Diffusion-style .ckpt files usually nest the weights under
    # "state_dict"; unwrap so the loop below sees actual tensors.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    tensor_dict = {}
    for key, val in checkpoint.items():
        if not isinstance(val, torch.Tensor):
            continue  # skip non-tensor entries (step counters, etc.) that would crash .to()
        if val.is_floating_point():
            # Only cast float tensors — float-casting integer tensors
            # (e.g. position ids) would corrupt them.
            val = val.to(dtype=target_dtype)
        tensor_dict[key] = val

    return tensor_dict
+
32
# This file is a library module; complain if someone executes it directly.
if __name__ == '__main__':
    print('__main__ not allowed in modules')