Init
This view is limited to 50 files because the commit contains too many changes.
- .gitignore +181 -0
- .gitmodules +12 -0
- FAQ.md +10 -0
- LICENSE +21 -0
- README.md +40 -12
- README_origin.md +468 -0
- app.py +6 -0
- assets/glif.svg +40 -0
- assets/lora_ease_ui.png +0 -0
- build_and_push_docker.yaml +8 -0
- config/examples/extract.example.yml +75 -0
- config/examples/generate.example.yaml +60 -0
- config/examples/mod_lora_scale.yaml +48 -0
- config/examples/modal/modal_train_lora_flux_24gb.yaml +96 -0
- config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml +98 -0
- config/examples/train_lora_flux_24gb.yaml +96 -0
- config/examples/train_lora_flux_schnell_24gb.yaml +98 -0
- config/examples/train_lora_sd35_large_24gb.yaml +97 -0
- config/examples/train_slider.example.yml +230 -0
- docker/Dockerfile +31 -0
- extensions/example/ExampleMergeModels.py +129 -0
- extensions/example/__init__.py +25 -0
- extensions/example/config/config.example.yaml +48 -0
- extensions_built_in/advanced_generator/Img2ImgGenerator.py +256 -0
- extensions_built_in/advanced_generator/PureLoraGenerator.py +102 -0
- extensions_built_in/advanced_generator/ReferenceGenerator.py +212 -0
- extensions_built_in/advanced_generator/__init__.py +59 -0
- extensions_built_in/advanced_generator/config/train.example.yaml +91 -0
- extensions_built_in/concept_replacer/ConceptReplacer.py +151 -0
- extensions_built_in/concept_replacer/__init__.py +26 -0
- extensions_built_in/concept_replacer/config/train.example.yaml +91 -0
- extensions_built_in/dataset_tools/DatasetTools.py +20 -0
- extensions_built_in/dataset_tools/SuperTagger.py +196 -0
- extensions_built_in/dataset_tools/SyncFromCollection.py +131 -0
- extensions_built_in/dataset_tools/__init__.py +43 -0
- extensions_built_in/dataset_tools/tools/caption.py +53 -0
- extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py +187 -0
- extensions_built_in/dataset_tools/tools/fuyu_utils.py +66 -0
- extensions_built_in/dataset_tools/tools/image_tools.py +49 -0
- extensions_built_in/dataset_tools/tools/llava_utils.py +85 -0
- extensions_built_in/dataset_tools/tools/sync_tools.py +279 -0
- extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py +235 -0
- extensions_built_in/image_reference_slider_trainer/__init__.py +25 -0
- extensions_built_in/image_reference_slider_trainer/config/train.example.yaml +107 -0
- extensions_built_in/sd_trainer/SDTrainer.py +1679 -0
- extensions_built_in/sd_trainer/__init__.py +30 -0
- extensions_built_in/sd_trainer/config/train.example.yaml +91 -0
- extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py +533 -0
- extensions_built_in/ultimate_slider_trainer/__init__.py +25 -0
- extensions_built_in/ultimate_slider_trainer/config/train.example.yaml +107 -0
.gitignore
ADDED
@@ -0,0 +1,181 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+/env.sh
+/models
+/custom/*
+!/custom/.gitkeep
+/.tmp
+/venv.bkp
+/venv.*
+/config/*
+!/config/examples
+!/config/_PUT_YOUR_CONFIGS_HERE).txt
+/output/*
+!/output/.gitkeep
+/extensions/*
+!/extensions/example
+/temp
+/wandb
+.vscode/settings.json
+.DS_Store
+._.DS_Store
+merge_file.py
.gitmodules
ADDED
@@ -0,0 +1,12 @@
+[submodule "repositories/sd-scripts"]
+    path = repositories/sd-scripts
+    url = https://github.com/kohya-ss/sd-scripts.git
+[submodule "repositories/leco"]
+    path = repositories/leco
+    url = https://github.com/p1atdev/LECO
+[submodule "repositories/batch_annotator"]
+    path = repositories/batch_annotator
+    url = https://github.com/ostris/batch-annotator
+[submodule "repositories/ipadapter"]
+    path = repositories/ipadapter
+    url = https://github.com/tencent-ailab/IP-Adapter.git
FAQ.md
ADDED
@@ -0,0 +1,10 @@
+# FAQ
+
+WIP. Will continue to add things as they are needed.
+
+## FLUX.1 Training
+
+#### How much VRAM is required to train a lora on FLUX.1?
+
+24GB minimum is required.
+
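The 24GB figure above pairs with the quantized setup used by the example configs added later in this commit. A minimal sketch of the relevant `model` block, with field names taken from `config/examples/train_lora_flux_24gb.yaml` in this diff (the `low_vram` flag is the one README_origin.md below recommends when the same GPU also drives your monitors):

```yaml
model:
  name_or_path: "black-forest-labs/FLUX.1-dev"
  is_flux: true
  quantize: true    # 8-bit mixed precision; part of what makes 24GB workable
  # low_vram: true  # uncomment if this GPU is also connected to your monitors (slower quantization)
```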
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Ostris, LLC
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,40 @@
+# FLUX LoRA Training on Modal
+
+## IMPORTANT - READ THIS
+
+This Space allows you to train LoRA models for FLUX using Modal's GPU resources. You **must** use your own API tokens for this Space to function.
+
+## Setup Instructions
+
+1. **Create Accounts:** Create free accounts on [Hugging Face](https://huggingface.co), [Modal](https://modal.com), and [Weights & Biases](https://wandb.ai) (if you want to save training info).
+2. **Get API Tokens:**
+   * **Hugging Face:** Obtain a "write" access token from your Hugging Face [settings/tokens](https://huggingface.co/settings/tokens).
+   * **Modal:** Get your Modal API token from your [Modal dashboard](https://modal.com/).
+   * **WandB:** Generate a WandB API key from your Weights & Biases settings if you plan to use WandB.
+3. **Duplicate This Space:** Duplicate (clone) this Hugging Face Space to your own account.
+4. **Add API Tokens as Secrets:** In your duplicated Space, navigate to "Settings" -> "Variables and Secrets" and add:
+   * `HF_TOKEN`: Your Hugging Face write access token.
+   * `WANDB_API_KEY`: Your Weights & Biases API key.
+5. **Upload your dataset**
+   * You must upload your dataset to the `/root/ai-toolkit` folder. The images can be in a subfolder, but that folder's path needs to begin with `/root/ai-toolkit/`. For instance: `/root/ai-toolkit/my-dataset`
+   * Make sure each image file name matches a corresponding text caption in the same folder, e.g. `image1.jpg` and `image1.txt`.
+   * You can upload a zip file of your dataset, or just a collection of images and text files.
+6. **Customize and Train:**
+   * Go to the `App` tab and use the Gradio interface to train your LoRA.
+   * Enter the required information, add your dataset, and configure the training parameters.
+   * Choose whether to upload a zip file or multiple images.
+   * Make sure your image file names match their corresponding `.txt` files.
+   * Click the "Start Training" button and wait for the training to complete (check Modal for logs, WandB for training data).
+   * The UI will automatically upload the file(s) to the Modal compute environment.
+7. **View Results:**
+   * Trained LoRA models will be automatically pushed to your Hugging Face account if you enable the option and have the necessary write token set.
+   * Samples, logs, optimizer state, and other training information will be stored on WandB if enabled.
+   * Models, optimizer state, and samples are always stored in `Storage > flux-lora-models` on Modal.
+
+## Notes
+
+* Training data must always be referenced in your config as a dataset folder under `/root/ai-toolkit/your-dataset`; this is required to train correctly. If you upload a folder named `my-dataset` as a zip, the folder path to reference will be `/root/ai-toolkit/my-dataset`.
+* Make sure each image has a corresponding text file with the same name.
+* Training and downloading samples in this mode will take a while, so be patient, especially in low VRAM mode.
+
+If you encounter any problems, please open a new issue on the [ostris/ai-toolkit](https://github.com/ostris/ai-toolkit) GitHub repository.
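The `/root/ai-toolkit/your-dataset` note above corresponds to a dataset entry in the training config. A minimal sketch using the field names from `config/examples/modal/modal_train_lora_flux_24gb.yaml` later in this diff; `my-dataset` is just the illustrative folder name from the note above:

```yaml
datasets:
  - folder_path: "/root/ai-toolkit/my-dataset"  # must live under /root/ai-toolkit on Modal
    caption_ext: "txt"                          # image1.jpg pairs with image1.txt
    resolution: [ 512, 768, 1024 ]
```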
README_origin.md
ADDED
@@ -0,0 +1,468 @@
+# AI Toolkit by Ostris
+
+## IMPORTANT NOTE - READ THIS
+This is my research repo. I do a lot of experiments in it and it is possible that I will break things.
+If something breaks, check out an earlier commit. This repo can train a lot of things, and it is
+hard to keep up with all of them.
+
+## Support my work
+
+<a href="https://glif.app" target="_blank">
+<img alt="glif.app" src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/glif.svg?v=1" width="256" height="auto">
+</a>
+
+
+My work on this project would not be possible without the amazing support of [Glif](https://glif.app/) and everyone on the
+team. If you want to support me, support Glif. [Join the site](https://glif.app/),
+[Join us on Discord](https://discord.com/invite/nuR9zZ2nsh), [follow us on Twitter](https://x.com/heyglif),
+and come make some cool stuff with us.
+
+## Installation
+
+Requirements:
+- python >3.10
+- Nvidia GPU with enough RAM to do what you need
+- python venv
+- git
+
+
+
+Linux:
+```bash
+git clone https://github.com/ostris/ai-toolkit.git
+cd ai-toolkit
+git submodule update --init --recursive
+python3 -m venv venv
+source venv/bin/activate
+# .\venv\Scripts\activate on windows
+# install torch first
+pip3 install torch
+pip3 install -r requirements.txt
+```
+
+Windows:
+```bash
+git clone https://github.com/ostris/ai-toolkit.git
+cd ai-toolkit
+git submodule update --init --recursive
+python -m venv venv
+.\venv\Scripts\activate
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
+pip install -r requirements.txt
+```
+
+## FLUX.1 Training
+
+### Tutorial
+
+To get started quickly, check out [@araminta_k](https://x.com/araminta_k)'s tutorial on [Finetuning Flux Dev on a 3090](https://www.youtube.com/watch?v=HzGW_Kyermg) with 24GB VRAM.
+
+
+### Requirements
+You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control
+your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize
+the model on CPU and should allow it to train with monitors attached. Users have gotten it to work on Windows with WSL,
+but there are some reports of a bug when running on Windows natively.
+I have only tested on Linux for now. This is still extremely experimental,
+and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all.
+
+### FLUX.1-dev
+
+FLUX.1-dev has a non-commercial license, which means anything you train will inherit the
+non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
+Otherwise, this will fail. Here are the required steps to set it up.
+
+1. Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+2. Make a file named `.env` in the root of this folder
+3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here`
+
+### FLUX.1-schnell
+
+FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want, and it does not require an HF_TOKEN to train.
+However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter).
+It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended.
+
+To use it, you just need to add the assistant to the `model` section of your config file like so:
+
+```yaml
+model:
+  name_or_path: "black-forest-labs/FLUX.1-schnell"
+  assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"
+  is_flux: true
+  quantize: true
+```
+
+You also need to adjust your sample steps since schnell does not require as many:
+
+```yaml
+sample:
+  guidance_scale: 1 # schnell does not do guidance
+  sample_steps: 4 # 1 - 4 works well
+```
+
+### Training
+1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
+2. Edit the file following the comments in the file
+3. Run the file like so `python run.py config/whatever_you_want.yml`
+
+A folder with the name and the training folder from the config file will be created when you start. It will have all
+checkpoints and images in it. You can stop the training at any time using ctrl+c, and when you resume, it will pick back up
+from the last checkpoint.
+
+IMPORTANT: if you press ctrl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving.
+
+### Need help?
+
+Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU)
+and ask for help there. However, please refrain from PMing me directly with general questions or support requests. Ask in the Discord
+and I will answer when I can.
+
+## Gradio UI
+
+To get started training locally with a custom UI, once you have followed the steps above and `ai-toolkit` is installed:
+
+```bash
+cd ai-toolkit # in case you are not yet in the ai-toolkit folder
+huggingface-cli login # provide a `write` token to publish your LoRA at the end
+python flux_train_ui.py
+```
+
+You will instantiate a UI that will let you upload your images, caption them, train and publish your LoRA
+![image](assets/lora_ease_ui.png)
+
+
+## Training in RunPod
+Example RunPod template: **runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04**
+> You need a minimum of 24GB VRAM, pick a GPU by your preference.
+
+#### Example config ($0.5/hr):
+- 1x A40 (48 GB VRAM)
+- 19 vCPU 100 GB RAM
+
+#### Custom overrides (you need some storage to clone FLUX.1, store datasets, store trained models and samples):
+- ~120 GB Disk
+- ~120 GB Pod Volume
+- Start Jupyter Notebook
+
+### 1. Setup
+```
+git clone https://github.com/ostris/ai-toolkit.git
+cd ai-toolkit
+git submodule update --init --recursive
+python -m venv venv
+source venv/bin/activate
+pip install torch
+pip install -r requirements.txt
+pip install --upgrade accelerate transformers diffusers huggingface_hub # Optional, run it if you run into issues
+```
+### 2. Upload your dataset
+- Create a new folder in the root, name it `dataset` or whatever you like.
+- Drag and drop your .jpg, .jpeg, or .png images and .txt files inside the newly created dataset folder.
+
+### 3. Log in to Hugging Face with an Access Token
+- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to the Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
+- Run ```huggingface-cli login``` and paste your token.
+
+### 4. Training
+- Copy an example config file located at ```config/examples``` to the config folder and rename it to ```whatever_you_want.yml```.
+- Edit the config following the comments in the file.
+- Change ```folder_path: "/path/to/images/folder"``` to your dataset path like ```folder_path: "/workspace/ai-toolkit/your-dataset"```.
+- Run the file: ```python run.py config/whatever_you_want.yml```.
+
+### Screenshot from RunPod
+<img width="1728" alt="RunPod Training Screenshot" src="https://github.com/user-attachments/assets/53a1b8ef-92fa-4481-81a7-bde45a14a7b5">
+
+## Training in Modal
+
+### 1. Setup
+#### ai-toolkit:
+```
+git clone https://github.com/ostris/ai-toolkit.git
+cd ai-toolkit
+git submodule update --init --recursive
+python -m venv venv
+source venv/bin/activate
+pip install torch
+pip install -r requirements.txt
+pip install --upgrade accelerate transformers diffusers huggingface_hub # Optional, run it if you run into issues
+```
+#### Modal:
+- Run `pip install modal` to install the modal Python package.
+- Run `modal setup` to authenticate (if this doesn't work, try `python -m modal setup`).
+
+#### Hugging Face:
+- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to the Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
+- Run `huggingface-cli login` and paste your token.
+
+### 2. Upload your dataset
+- Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`.
+
+### 3. Configs
+- Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```.
+- Edit the config following the comments in the file, **<ins>be careful and follow the example `/root/ai-toolkit` paths</ins>**.
+
+### 4. Edit run_modal.py
+- Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like:
+
+```
+code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
+```
+- Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_.
+
+### 5. Training
+- Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`.
+- You can monitor your training in your local terminal, or on [modal.com](https://modal.com/).
+- Models, samples and optimizer will be stored in `Storage > flux-lora-models`.
+
+### 6. Saving the model
+- Check contents of the volume by running `modal volume ls flux-lora-models`.
+- Download the content by running `modal volume get flux-lora-models your-model-name`.
+- Example: `modal volume get flux-lora-models my_first_flux_lora_v1`.
+
+### Screenshot from Modal
+
+<img width="1728" alt="Modal Training Screenshot" src="https://github.com/user-attachments/assets/7497eb38-0090-49d6-8ad9-9c8ea7b5388b">
+
+---
+
+## Dataset Preparation
+
+Datasets generally need to be a folder containing images and associated text files. Currently, the only supported
+formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images
+but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption.
+You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically
+replaced.
+
+Images are never upscaled, but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**.
+The loader will automatically resize them and can handle varying aspect ratios.
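The `trigger_word` replacement described above maps to a single key in the trainer config. A minimal sketch using field names from the example configs in this diff; `p3r5on` is the placeholder value those examples use and `/path/to/images/folder` is the placeholder path from the RunPod instructions above:

```yaml
process:
  - type: 'sd_trainer'
    trigger_word: "p3r5on"  # every "[trigger]" in a caption .txt is replaced with this word
    datasets:
      - folder_path: "/path/to/images/folder"
        caption_ext: "txt"  # image2.jpg pairs with image2.txt
```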
+
+
+## Training Specific Layers
+
+To train specific layers with LoRA, you can use the `only_if_contains` network kwarg. For instance, if you want to train only the 2 layers
+used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your
+network kwargs like so:
+
+```yaml
+network:
+  type: "lora"
+  linear: 128
+  linear_alpha: 128
+  network_kwargs:
+    only_if_contains:
+      - "transformer.single_transformer_blocks.7.proj_out"
+      - "transformer.single_transformer_blocks.20.proj_out"
+```
+
+The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal
+the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights.
+For instance, to only train the `single_transformer` for FLUX.1, you can use the following:
+
+```yaml
+network:
+  type: "lora"
+  linear: 128
+  linear_alpha: 128
+  network_kwargs:
+    only_if_contains:
+      - "transformer.single_transformer_blocks."
+```
+
+You can also exclude layers by their names by using the `ignore_if_contains` network kwarg. So to exclude all the single transformer blocks:
+
+
+```yaml
+network:
+  type: "lora"
+  linear: 128
+  linear_alpha: 128
+  network_kwargs:
+    ignore_if_contains:
+      - "transformer.single_transformer_blocks."
+```
+
+`ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
+it will be ignored.
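A minimal sketch combining the two kwargs to illustrate that priority rule; this assumes the two keys can be set together, which the text above implies, and it reuses the layer names from the earlier examples:

```yaml
network:
  type: "lora"
  linear: 128
  linear_alpha: 128
  network_kwargs:
    only_if_contains:
      - "transformer.single_transformer_blocks."
    ignore_if_contains:
      - "transformer.single_transformer_blocks.7.proj_out"  # skipped even though it also matches only_if_contains
```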
+
+---
+
+## EVERYTHING BELOW THIS LINE IS OUTDATED
+
+It may still work like that, but I have not tested it in a while.
+
+---
+
+### Batch Image Generation
+
+An image generator that can take prompts from a config file or from a txt file and generate them to a
+folder. I mainly needed this for an SDXL test I am doing, but added some polish to it so it can be used
+for general batch image generation.
+It all runs off a config file, which you can find an example of in `config/examples/generate.example.yaml`.
+More info is in the comments in the example.
+
+---
+
+### LoRA (lierla), LoCON (LyCORIS) extractor
+
+It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adds some QOL features
+and LoRA (lierla) support. It can do multiple types of extractions in one run.
+It all runs off a config file, which you can find an example of in `config/examples/extract.example.yml`.
+Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`.
+Then you can edit the file to your liking and call it like so:
+
+```bash
+python3 run.py config/whatever_you_want.yml
+```
+
+You can also put a full path to a config file, if you want to keep it somewhere else.
+
+```bash
+python3 run.py "/home/user/whatever_you_want.yml"
+```
+
+More notes on how it works are available in the example config file itself. LoRA and LoCON both support
+extractions of 'fixed', 'threshold', 'ratio', 'quantile'. I'll update what these do and mean later.
+Most people used fixed, which is traditional fixed-dimension extraction.
+
+`process` is an array of different processes to run. You can add a few and mix and match. One LoRA, one LoCON, etc.
+
+---
+
+### LoRA Rescale
+
+Change `<lora:my_lora:4.6>` to `<lora:my_lora:1.0>` or whatever you want with the same effect.
+A tool for rescaling a LoRA's weights. It should work with LoCON as well, but I have not tested it.
+It all runs off a config file, which you can find an example of in `config/examples/mod_lora_scale.yaml`.
+Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`.
+Then you can edit the file to your liking and call it like so:
+
+```bash
+python3 run.py config/whatever_you_want.yml
+```
+
+You can also put a full path to a config file, if you want to keep it somewhere else.
+
+```bash
+python3 run.py "/home/user/whatever_you_want.yml"
+```
+
+More notes on how it works are available in the example config file itself. This is useful when making
+all LoRAs, as the ideal weight is rarely 1.0, but now you can fix that. For sliders, they can have weird scales from -2 to 2
+or even -15 to 15. This will allow you to dial it in so they all have your desired scale.
+
+---
+
+### LoRA Slider Trainer
+
+<a target="_blank" href="https://colab.research.google.com/github/ostris/ai-toolkit/blob/main/notebooks/SliderTraining.ipynb">
+<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
+</a>
+
+This is how I train most of the recent sliders I have on Civitai; you can check them out in my [Civitai profile](https://civitai.com/user/Ostris/models).
+It is based off the work by [p1atdev/LECO](https://github.com/p1atdev/LECO) and [rohitgandikota/erasing](https://github.com/rohitgandikota/erasing),
+but has been heavily modified to create sliders rather than erasing concepts. I have a lot more plans for this, but it is
+very functional as is. It is also very easy to use. Just copy the example config file in `config/examples/train_slider.example.yml`
+to the `config` folder and rename it to `whatever_you_want.yml`. Then you can edit the file to your liking and call it like so:
+
+```bash
+python3 run.py config/whatever_you_want.yml
+```
+
+There is a lot more information in that example file. You can even run the example as is without any modifications to see
+how it works. It will create a slider that turns all animals into dogs (neg) or cats (pos). Just run it like so:
+
+```bash
+python3 run.py config/examples/train_slider.example.yml
+```
+
+And you will be able to see how it works without configuring anything. No datasets are required for this method.
+I will post a better tutorial soon.
+
+---
+
+## Extensions!!
+
+You can now make and share custom extensions that run within this framework and have all the inbuilt tools
+available to them. I will probably use this as the primary development method going
+forward so I don't keep adding more and more features to this base repo. I will likely migrate a lot
+of the existing functionality as well to make everything modular. There is an example extension in the `extensions`
+folder that shows how to make a model merger extension. All of the code is heavily documented, which is hopefully
+enough to get you started. To make an extension, just copy that example and replace all the things you need to.
+
+
+### Model Merger - Example Extension
+It is located in the `extensions` folder. It is a fully functional model merger that can merge as many models together
+as you want. It is a good example of how to make an extension, but is also a pretty useful feature as well since most
+mergers can only do one model at a time and this one will take as many as you want to feed it. There is an
+example config file in there; just copy that to your `config` folder, rename it to `whatever_you_want.yml`,
+and use it like any other config file.
+
+## WIP Tools
+
+
+### VAE (Variational Auto Encoder) Trainer
+
+This works, but is not ready for others to use and therefore does not have an example config.
+I am still working on it. I will update this when it is ready.
+I am adding a lot of features for criteria that I have used in my image enlargement work: a critic (discriminator),
+content loss, style loss, and a few more. If you don't know, the VAEs
+for stable diffusion (yes, even the MSE one, and SDXL) are horrible at smaller faces, and it holds SD back. I will fix this.
+I'll post more about this with better examples later, but here is a quick test of a run through with various VAEs.
+Just went in and out. It is much worse on smaller faces than shown here.
+
+<img src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/VAE_test1.jpg" width="768" height="auto">
+
+---
+
+## TODO
+- [X] Add proper regs on sliders
+- [X] Add SDXL support (base model only for now)
+- [ ] Add plain erasing
+- [ ] Make Textual Inversion network trainer (network that spits out TI embeddings)
+
+---
+
+## Change Log
+
+#### 2023-08-05
+- Huge memory rework and slider rework. Slider training is better than ever with no more
+RAM spikes. I also made it so all 4 parts of the slider algorithm run in one batch so they share gradient
+accumulation. This makes it much faster and more stable.
+- Updated the example config to be something more practical and more updated to current methods. It is now
+a detail slider and shows how to train one without a subject. 512x512 slider training for 1.5 should work on
+a 6GB GPU now. Will test soon to verify.
+
+
+#### 2021-10-20
+- Windows support bug fixes
+- Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
+Check out the example in the `extensions` folder. Read more about that above.
+- Model Merging, provided via the example extension.
+
+#### 2023-08-03
+Another big refactor to make SD more modular.
+
+Made batch image generation script.
+
+#### 2023-08-01
+Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
+Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
+at the moment, so hopefully there are no breaking changes.
+
+Unfortunately, I am too lazy to write a proper changelog with all the changes.
+
+I added SDXL training to sliders... but it does not work properly.
+The slider training relies on a model's ability to understand that an unconditional (negative prompt)
+means you do not want that concept in the output. SDXL does not understand this for whatever reason,
+which makes separating out
+concepts within the model hard. I am sure the community will find a way to fix this
+over time, but for now, it is not
+going to work properly. And if any of you are thinking "Could we maybe fix it by adding 1 or 2 more text
+encoders to the model as well as a few more entirely separate diffusion networks?" No. God no. It just needs a little
+training without every experimental new paper added to it. The KISS principle.
+
+
+#### 2023-07-30
+Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
+regularizer. You can set the network multiplier to force spread consistency at high weights.
+
app.py
ADDED
@@ -0,0 +1,6 @@
+import os
+import sys
+from hf_ui import demo
+
+if __name__ == "__main__":
+    demo.launch()
assets/glif.svg
ADDED
assets/lora_ease_ui.png
ADDED
build_and_push_docker.yaml
ADDED
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
+# wait 2 seconds
+sleep 2
+docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:latest -f docker/Dockerfile .
+docker tag aitoolkit:latest ostris/aitoolkit:latest
+docker push ostris/aitoolkit:latest
config/examples/extract.example.yml
ADDED
@@ -0,0 +1,75 @@
+---
+# this is in yaml format. You can use json if you prefer
+# I like both but yaml is easier to read and write
+# plus it has comments which is nice for documentation
+job: extract # tells the runner what to do
+config:
+  # the name will be used to create a folder in the output folder
+  # it will also replace any [name] token in the rest of this config
+  name: name_of_your_model
+  # can be hugging face model, a .ckpt, or a .safetensors
+  base_model: "/path/to/base/model.safetensors"
+  # can be hugging face model, a .ckpt, or a .safetensors
+  extract_model: "/path/to/model/to/extract/trained.safetensors"
+  # we will create a folder here with the name above. This will create /path/to/output/folder/name_of_your_model
+  output_folder: "/path/to/output/folder"
+  is_v2: false
+  dtype: fp16 # saved dtype
+  device: cpu # cpu, cuda:0, etc
+
+  # processes can be chained like this to run multiple in a row
+  # they must all use same models above, but great for testing different
+  # sizes and types of extractions. It is much faster as we already have the models loaded
+  process:
+    # process 1
+    - type: locon # locon or lora (locon is lycoris)
+      filename: "[name]_64_32.safetensors" # will be put in output folder
+      dtype: fp16
+      mode: fixed
+      linear: 64
+      conv: 32
+
+    # process 2
+    - type: locon
+      output_path: "/absolute/path/for/this/output.safetensors" # can be absolute
+      mode: ratio
+      linear: 0.2
+      conv: 0.2
+
+    # process 3
+    - type: locon
+      filename: "[name]_ratio_02.safetensors"
+      mode: quantile
+      linear: 0.5
+      conv: 0.5
+
+    # process 4
+    - type: lora # traditional lora extraction (lierla) with linear layers only
+      filename: "[name]_4.safetensors"
+      mode: fixed # fixed, ratio, quantile supported for lora as well
+      linear: 4 # lora dim or rank
+      # no conv for lora
+
+    # process 5
+    - type: lora
+      filename: "[name]_q05.safetensors"
+      mode: quantile
+      linear: 0.5
+
+# you can put any information you want here, and it will be saved in the model
+# the below is an example. I recommend doing trigger words at a minimum
+# in the metadata. The software will include this plus some other information
+meta:
+  name: "[name]" # [name] gets replaced with the name above
+  description: A short description of your model
+  trigger_words:
+    - put
+    - trigger
+    - words
+    - here
+  version: '0.1'
+  creator:
+    name: Your Name
+    email: [email protected]
+    website: https://yourwebsite.com
+  any: All meta data above is arbitrary, it can be whatever you want.
config/examples/generate.example.yaml
ADDED
@@ -0,0 +1,60 @@
+---
+
+job: generate # tells the runner what to do
+config:
+  name: "generate" # this is not really used anywhere currently but required by runner
+  process:
+    # process 1
+    - type: to_folder # process images to a folder
+      output_folder: "output/gen"
+      device: cuda:0 # cpu, cuda:0, etc
+      generate:
+        # these are your defaults, you can override most of them with flags
+        sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now
+        width: 1024
+        height: 1024
+        neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
+        seed: -1 # -1 is random
+        guidance_scale: 7
+        sample_steps: 20
+        ext: ".png" # .png, .jpg, .jpeg, .webp
+
+        # here are the flags you can use for prompts. Always start with
+        # your prompt first then add these flags after. You can use as many
+        # as you like:
+        # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
+        # we will try to support all sd-scripts flags where we can
+
+        # FROM SD-SCRIPTS
+        # --n Treat everything until the next option as a negative prompt.
+        # --w Specify the width of the generated image.
+        # --h Specify the height of the generated image.
+        # --d Specify the seed for the generated image.
+        # --l Specify the CFG scale for the generated image.
+        # --s Specify the number of steps during generation.
+
+        # OURS and some QOL additions
+        # --p2 Prompt for the second text encoder (SDXL only)
+        # --n2 Negative prompt for the second text encoder (SDXL only)
+        # --gr Specify the guidance rescale for the generated image (SDXL only)
+        # --seed Specify the seed for the generated image same as --d
+        # --cfg Specify the CFG scale for the generated image same as --l
+        # --steps Specify the number of steps during generation same as --s
+
+        prompt_file: false # if true a txt file will be created next to images with prompt strings used
+        # prompts can also be a path to a text file with one prompt per line
+        # prompts: "/path/to/prompts.txt"
+        prompts:
+          - "photo of batman"
+          - "photo of superman"
+          - "photo of spiderman"
+          - "photo of a superhero --n batman superman spiderman"
+
+      model:
+        # huggingface name, relative from project path, or absolute path to .safetensors or .ckpt
+        # name_or_path: "runwayml/stable-diffusion-v1-5"
+        name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
+        is_v2: false # for v2 models
+        is_v_pred: false # for v-prediction models (most v2 models)
+        is_xl: false # for SDXL models
+        dtype: bf16
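The per-prompt flags documented in the comments of this file can be mixed into any `prompts` entry. A minimal sketch: the first line is the exact `photo of a baseball` example from those comments, and the second merely recombines the same documented flags (the specific values are illustrative only):

```yaml
prompts:
  - "photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20"
  - "photo of batman --w 768 --h 1024"  # flags override the defaults set under generate:
```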
config/examples/mod_lora_scale.yaml
ADDED
@@ -0,0 +1,48 @@
+---
+job: mod
+config:
+  name: name_of_your_model_v1
+  process:
+    - type: rescale_lora
+      # path to your current lora model
+      input_path: "/path/to/lora/lora.safetensors"
+      # output path for your new lora model, can be the same as input_path to replace
+      output_path: "/path/to/lora/output_lora_v1.safetensors"
+      # replaces meta with the meta below (plus minimum meta fields)
+      # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
+      replace_meta: true
+      # how to adjust, we can scale the up_down weights or the alpha
+      # up_down is the default and probably the best, they will both net the same outputs
+      # would only affect rare NaN cases and maybe merging with old merge tools
+      scale_target: 'up_down'
+      # precision to save, fp16 is the default and standard
+      save_dtype: fp16
+      # current_weight is the ideal weight you use as a multiplier when using the lora
+      # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
+      # you can do negatives here too if you want to flip the lora
+      current_weight: 6.0
+      # target_weight is the ideal weight you use as a multiplier when using the lora
+      # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
+      # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
+      target_weight: 1.0
+
+      # base model for the lora
+      # this is just used to add meta so automatic1111 knows which model it is for
+      # assume v1.5 if these are not set
+      is_xl: false
+      is_v2: false
+meta:
+  # this is only used if you set replace_meta to true above
+  name: "[name]" # [name] gets replaced with the name above
+  description: A short description of your lora
+  trigger_words:
+    - put
+    - trigger
+    - words
+    - here
+  version: '0.1'
+  creator:
+    name: Your Name
+    email: [email protected]
+    website: https://yourwebsite.com
+  any: All meta data above is arbitrary, it can be whatever you want.
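README_origin.md above introduces this tool with the `<lora:my_lora:4.6>` to `<lora:my_lora:1.0>` example. A minimal sketch of the two fields that change for that case, all other fields as in the example file above:

```yaml
current_weight: 4.6  # the multiplier the LoRA currently needs, i.e. <lora:my_lora:4.6>
target_weight: 1.0   # after rescaling, <lora:my_lora:1.0> gives the same effect
```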
config/examples/modal/modal_train_lora_flux_24gb.yaml
ADDED
@@ -0,0 +1,96 @@
+---
+job: extension
+config:
+  # this name will be the folder and filename name
+  name: "my_first_flux_lora_v1"
+  process:
+    - type: 'sd_trainer'
+      # root folder to save training sessions/samples/weights
+      training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
+      # uncomment to see performance stats in the terminal every N steps
+      # performance_log_every: 1000
+      device: cuda:0
+      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
+      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
+      # trigger_word: "p3r5on"
+      network:
+        type: "lora"
+        linear: 16
+        linear_alpha: 16
+      save:
+        dtype: float16 # precision to save
+        save_every: 250 # save every this many steps
+        max_step_saves_to_keep: 4 # how many intermittent saves to keep
+      datasets:
+        # datasets are a folder of images. captions need to be txt files with the same name as the image
+        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
+        # images will automatically be resized and bucketed into the resolution specified
+        # on windows, escape back slashes with another backslash so
+        # "C:\\path\\to\\images\\folder"
+        # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
+        - folder_path: "/root/ai-toolkit/your-dataset"
+          caption_ext: "txt"
+          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
+          shuffle_tokens: false # shuffle caption order, split by commas
+          cache_latents_to_disk: true # leave this true unless you know what you're doing
+          resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
+      train:
+        batch_size: 1
+        steps: 2000 # total number of steps to train; 500 - 4000 is a good range
+        gradient_accumulation_steps: 1
+        train_unet: true
+        train_text_encoder: false # probably won't work with flux
+        gradient_checkpointing: true # need this on unless you have a ton of vram
+        noise_scheduler: "flowmatch" # for training only
+        optimizer: "adamw8bit"
+        lr: 1e-4
+        # uncomment this to skip the pre training sample
+        # skip_first_sample: true
+        # uncomment to completely disable sampling
+        # disable_sampling: true
+        # uncomment to use new bell curved weighting. Experimental but may produce better results
+        # linear_timesteps: true
+
+        # ema will smooth out learning, but could slow it down. Recommended to leave on.
+        ema_config:
+          use_ema: true
+          ema_decay: 0.99
+
+        # will probably need this if gpu supports it for flux, other dtypes may not work correctly
+        dtype: bf16
+      model:
+        # huggingface model name or path
+        # if you get an error, or get stuck while downloading,
+        # check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and
+        # place it like "/root/ai-toolkit/FLUX.1-dev"
+        name_or_path: "black-forest-labs/FLUX.1-dev"
+        is_flux: true
+        quantize: true # run 8bit mixed precision
+        # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
+      sample:
+        sampler: "flowmatch" # must match train.noise_scheduler
+        sample_every: 250 # sample every this many steps
+        width: 1024
+        height: 1024
+        prompts:
+          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
+          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
+          - "woman with red hair, playing chess at the park, bomb going off in the background"
+          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
+          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
+          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
+          - "a bear building a log cabin in the snow covered mountains"
+          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
+          - "hipster man with a beard, building a chair, in a wood shop"
+          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
+          - "a man holding a sign that says, 'this is a sign'"
+          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
+        neg: "" # not used on flux
+        seed: 42
+        walk_seed: true
+        guidance_scale: 4
+        sample_steps: 20
+# you can add any additional meta info here. [name] is replaced with config name at top
+meta:
+  name: "[name]"
+  version: '1.0'
config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
---
job: extension
config:
  # this name will be the folder and filename name
  name: "my_first_flux_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "/root/ai-toolkit/modal_output"  # must match MOUNT_DIR from run_modal.py
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16  # precision to save
        save_every: 250  # save every this many steps
        max_step_saves_to_keep: 4  # how many intermittent saves to keep
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape back slashes with another backslash so
        # "C:\\path\\to\\images\\folder"
        # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
        - folder_path: "/root/ai-toolkit/your-dataset"
          caption_ext: "txt"
          caption_dropout_rate: 0.05  # will drop out the caption 5% of the time
          shuffle_tokens: false  # shuffle caption order, split by commas
          cache_latents_to_disk: true  # leave this true unless you know what you're doing
          resolution: [ 512, 768, 1024 ]  # flux enjoys multiple resolutions
      train:
        batch_size: 1
        steps: 2000  # total number of steps to train. 500 - 4000 is a good range
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false  # probably won't work with flux
        gradient_checkpointing: true  # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch"  # for training only
        optimizer: "adamw8bit"
        lr: 1e-4
        # uncomment this to skip the pre training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true
        # uncomment to use new bell curved weighting. Experimental but may produce better results
        # linear_timesteps: true

        # ema will smooth out learning, but could slow it down. Recommended to leave on.
        ema_config:
          use_ema: true
          ema_decay: 0.99

        # will probably need this if gpu supports it for flux, other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path
        # if you get an error, or get stuck while downloading,
        # check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and
        # place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter"
        name_or_path: "black-forest-labs/FLUX.1-schnell"
        assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"  # Required for flux schnell training
        is_flux: true
        quantize: true  # run 8bit mixed precision
        # low_vram is painfully slow to fuse in the adapter, avoid it unless absolutely necessary
        # low_vram: true  # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
      sample:
        sampler: "flowmatch"  # must match train.noise_scheduler
        sample_every: 250  # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: ""  # not used on flux
        seed: 42
        walk_seed: true
        guidance_scale: 1  # schnell does not do guidance
        sample_steps: 4  # 1 - 4 works well
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
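caption_dropout_rate: 0.05 in the dataset block above means the caption is randomly dropped on roughly 5% of training steps, which helps the LoRA still work when no trigger text is given. A minimal sketch of that idea (hypothetical helper, not the toolkit's code; dropping to an empty caption is an assumption for illustration):

# caption dropout sketch; 0.05 mirrors caption_dropout_rate above
import random

def maybe_drop_caption(caption: str, dropout_rate: float = 0.05) -> str:
    # return an empty caption dropout_rate of the time, otherwise the original caption
    return "" if random.random() < dropout_rate else caption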
config/examples/train_lora_flux_24gb.yaml
ADDED
@@ -0,0 +1,96 @@
---
job: extension
config:
  # this name will be the folder and filename name
  name: "my_first_flux_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16  # precision to save
        save_every: 250  # save every this many steps
        max_step_saves_to_keep: 4  # how many intermittent saves to keep
        push_to_hub: false  # change this to true to push your trained model to Hugging Face.
        # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
        # hf_repo_id: your-username/your-model-slug
        # hf_private: true  # whether the repo is private or public
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape back slashes with another backslash so
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05  # will drop out the caption 5% of the time
          shuffle_tokens: false  # shuffle caption order, split by commas
          cache_latents_to_disk: true  # leave this true unless you know what you're doing
          resolution: [ 512, 768, 1024 ]  # flux enjoys multiple resolutions
      train:
        batch_size: 1
        steps: 2000  # total number of steps to train. 500 - 4000 is a good range
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false  # probably won't work with flux
        gradient_checkpointing: true  # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch"  # for training only
        optimizer: "adamw8bit"
        lr: 1e-4
        # uncomment this to skip the pre training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true
        # uncomment to use new bell curved weighting. Experimental but may produce better results
        # linear_timesteps: true

        # ema will smooth out learning, but could slow it down. Recommended to leave on.
        ema_config:
          use_ema: true
          ema_decay: 0.99

        # will probably need this if gpu supports it for flux, other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path
        name_or_path: "black-forest-labs/FLUX.1-dev"
        is_flux: true
        quantize: true  # run 8bit mixed precision
        # low_vram: true  # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
      sample:
        sampler: "flowmatch"  # must match train.noise_scheduler
        sample_every: 250  # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: ""  # not used on flux
        seed: 42
        walk_seed: true
        guidance_scale: 4
        sample_steps: 20
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
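The trigger_word option above is described as being added to captions that don't already contain it, with [trigger] acting as a placeholder. A small illustrative sketch of that behavior (hypothetical helper; the prepend rule is an assumption, not the toolkit's exact implementation):

# trigger word handling sketch; function name and prepend behavior are illustrative assumptions
def apply_trigger_word(caption: str, trigger_word: str) -> str:
    # replace the [trigger] placeholder if present, otherwise add the trigger word when missing
    if "[trigger]" in caption:
        return caption.replace("[trigger]", trigger_word)
    if trigger_word not in caption:
        return f"{trigger_word}, {caption}"
    return caption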
config/examples/train_lora_flux_schnell_24gb.yaml
ADDED
@@ -0,0 +1,98 @@
---
job: extension
config:
  # this name will be the folder and filename name
  name: "my_first_flux_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16  # precision to save
        save_every: 250  # save every this many steps
        max_step_saves_to_keep: 4  # how many intermittent saves to keep
        push_to_hub: false  # change this to true to push your trained model to Hugging Face.
        # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
        # hf_repo_id: your-username/your-model-slug
        # hf_private: true  # whether the repo is private or public
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape back slashes with another backslash so
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05  # will drop out the caption 5% of the time
          shuffle_tokens: false  # shuffle caption order, split by commas
          cache_latents_to_disk: true  # leave this true unless you know what you're doing
          resolution: [ 512, 768, 1024 ]  # flux enjoys multiple resolutions
      train:
        batch_size: 1
        steps: 2000  # total number of steps to train. 500 - 4000 is a good range
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false  # probably won't work with flux
        gradient_checkpointing: true  # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch"  # for training only
        optimizer: "adamw8bit"
        lr: 1e-4
        # uncomment this to skip the pre training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true
        # uncomment to use new bell curved weighting. Experimental but may produce better results
        # linear_timesteps: true

        # ema will smooth out learning, but could slow it down. Recommended to leave on.
        ema_config:
          use_ema: true
          ema_decay: 0.99

        # will probably need this if gpu supports it for flux, other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path
        name_or_path: "black-forest-labs/FLUX.1-schnell"
        assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"  # Required for flux schnell training
        is_flux: true
        quantize: true  # run 8bit mixed precision
        # low_vram is painfully slow to fuse in the adapter, avoid it unless absolutely necessary
        # low_vram: true  # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
      sample:
        sampler: "flowmatch"  # must match train.noise_scheduler
        sample_every: 250  # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: ""  # not used on flux
        seed: 42
        walk_seed: true
        guidance_scale: 1  # schnell does not do guidance
        sample_steps: 4  # 1 - 4 works well
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
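The dataset comment notes that images "will automatically be resized and bucketed into the resolution specified". As a rough illustration of the idea behind resolution bucketing (a hypothetical sketch only; the nearest-area rule is an assumption, not the toolkit's bucketing code), each image is assigned to the configured resolution it is closest to:

# resolution bucketing sketch; bucket choice by nearest square area is an illustrative assumption
def pick_bucket(width: int, height: int, resolutions=(512, 768, 1024)) -> int:
    # pick the configured resolution whose square area is closest to the image's area
    area = width * height
    return min(resolutions, key=lambda r: abs(r * r - area))

print(pick_bucket(800, 600))  # -> 768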
config/examples/train_lora_sd35_large_24gb.yaml
ADDED
@@ -0,0 +1,97 @@
---
# NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
job: extension
config:
  # this name will be the folder and filename name
  name: "my_first_sd3l_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16  # precision to save
        save_every: 250  # save every this many steps
        max_step_saves_to_keep: 4  # how many intermittent saves to keep
        push_to_hub: false  # change this to true to push your trained model to Hugging Face.
        # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
        # hf_repo_id: your-username/your-model-slug
        # hf_private: true  # whether the repo is private or public
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape back slashes with another backslash so
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05  # will drop out the caption 5% of the time
          shuffle_tokens: false  # shuffle caption order, split by commas
          cache_latents_to_disk: true  # leave this true unless you know what you're doing
          resolution: [ 1024 ]
      train:
        batch_size: 1
        steps: 2000  # total number of steps to train. 500 - 4000 is a good range
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false  # May not fully work with SD3 yet
        gradient_checkpointing: true  # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch"
        timestep_type: "linear"  # linear or sigmoid
        optimizer: "adamw8bit"
        lr: 1e-4
        # uncomment this to skip the pre training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true
        # uncomment to use new bell curved weighting. Experimental but may produce better results
        # linear_timesteps: true

        # ema will smooth out learning, but could slow it down. Recommended to leave on.
        ema_config:
          use_ema: true
          ema_decay: 0.99

        # will probably need this if gpu supports it for sd3, other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path
        name_or_path: "stabilityai/stable-diffusion-3.5-large"
        is_v3: true
        quantize: true  # run 8bit mixed precision
      sample:
        sampler: "flowmatch"  # must match train.noise_scheduler
        sample_every: 250  # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: ""
        seed: 42
        walk_seed: true
        guidance_scale: 4
        sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
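max_step_saves_to_keep: 4 in the save block above means only the most recent intermediate checkpoints are kept on disk. A minimal sketch of that kind of pruning (hypothetical helper and file layout, not the toolkit's actual save logic):

# checkpoint pruning sketch; the file naming/layout is an assumption for illustration
import glob
import os

def prune_step_saves(folder: str, keep: int = 4) -> None:
    # delete all but the newest `keep` step checkpoints in a folder
    saves = sorted(glob.glob(os.path.join(folder, "*.safetensors")), key=os.path.getmtime)
    for path in saves[:-keep]:
        os.remove(path)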
config/examples/train_slider.example.yml
ADDED
@@ -0,0 +1,230 @@
---
# This is in yaml format. You can use json if you prefer
# I like both but yaml is easier to write
# Plus it has comments which is nice for documentation
# This is the config I use on my sliders. It is solid and tested
job: train
config:
  # the name will be used to create a folder in the output folder
  # it will also replace any [name] token in the rest of this config
  name: detail_slider_v1
  # folder will be created with name above in folder below
  # it can be relative to the project root or absolute
  training_folder: "output/LoRA"
  device: cuda:0  # cpu, cuda:0, etc
  # for tensorboard logging, we will make a subfolder for this job
  log_dir: "output/.tensorboard"
  # you can stack processes for other jobs, it is not tested with sliders though
  # just use one for now
  process:
    - type: slider  # tells runner to run the slider process
      # network is the LoRA network for a slider, I recommend to leave this be
      network:
        # network type lierla is traditional LoRA that works everywhere, only linear layers
        type: "lierla"
        # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
        linear: 8
        linear_alpha: 4  # Do about half of rank
      # training config
      train:
        # this is also used in sampling. Stick with ddpm unless you know what you are doing
        noise_scheduler: "ddpm"  # or "ddpm", "lms", "euler_a"
        # how many steps to train. More is not always better. I rarely go over 1000
        steps: 500
        # I have had good results with 4e-4 to 1e-4 at 500 steps
        lr: 2e-4
        # enables gradient checkpointing, saves vram, leave it on
        gradient_checkpointing: true
        # train the unet. I recommend leaving this true
        train_unet: true
        # train the text encoder. I don't recommend this unless you have a special use case
        # for sliders we are adjusting representation of the concept (unet),
        # not the description of it (text encoder)
        train_text_encoder: false
        # same as from sd-scripts, not fully tested but should speed up training
        min_snr_gamma: 5.0
        # just leave unless you know what you are doing
        # also supports "dadaptation" but set lr to 1 if you use that,
        # but it learns too fast and I don't recommend it
        optimizer: "adamw"
        # only constant for now
        lr_scheduler: "constant"
        # we randomly denoise a random num of steps from 1 to this number
        # while training. Just leave it
        max_denoising_steps: 40
        # works great at 1. I do 1 even with my 4090.
        # higher may not work right with newer single batch stacking code anyway
        batch_size: 1
        # bf16 works best if your GPU supports it (modern)
        dtype: bf16  # fp32, bf16, fp16
        # if you have it, use it. It is faster and better
        # torch 2.0 doesn't need xformers anymore, only use if you have lower version
        # xformers: true
        # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
        # although, the way we train sliders is comparative, so it probably won't work anyway
        noise_offset: 0.0
        # noise_offset: 0.0357  # SDXL was trained with offset of 0.0357. So use that when training on SDXL

      # the model to train the LoRA network on
      model:
        # huggingface name, relative to project path, or absolute path to .safetensors or .ckpt
        name_or_path: "runwayml/stable-diffusion-v1-5"
        is_v2: false  # for v2 models
        is_v_pred: false  # for v-prediction models (most v2 models)
        # has some issues with the dual text encoder and the way we train sliders
        # it works but weights probably need to be higher to see it.
        is_xl: false  # for SDXL models

      # saving config
      save:
        dtype: float16  # precision to save. I recommend float16
        save_every: 50  # save every this many steps
        # this will remove step counts more than this number
        # allows you to save more often in case of a crash without filling up your drive
        max_step_saves_to_keep: 2

      # sampling config
      sample:
        # must match train.noise_scheduler, this is not used here
        # but may be in future and in other processes
        sampler: "ddpm"
        # sample every this many steps
        sample_every: 20
        # image size
        width: 512
        height: 512
        # prompts to use for sampling. Do as many as you want, but it slows down training
        # pick ones that will best represent the concept you are trying to adjust
        # allows some flags after the prompt
        # --m [number]  # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
        #               # slide are good tests. will inherit sample.network_multiplier if not set
        # --n [string]  # negative prompt, will inherit sample.neg if not set
        # Only 75 tokens allowed currently
        # I like to do a wide positive and negative spread so I can see a good range and stop
        # early if the network is breaking down
        prompts:
          - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
          - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
          - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
          - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
          - "a golden retriever sitting on a leather couch, --m -5"
          - "a golden retriever sitting on a leather couch --m -3"
          - "a golden retriever sitting on a leather couch --m 3"
          - "a golden retriever sitting on a leather couch --m 5"
          - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
          - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
          - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
          - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
        # negative prompt used on all prompts above as default if they don't have one
        neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
        # seed for sampling. 42 is the answer for everything
        seed: 42
        # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
        # will start over on next sample_every so s1 is always seed
        # works well if you use same prompt but want different results
        walk_seed: false
        # cfg scale (4 to 10 is good)
        guidance_scale: 7
        # sampler steps (20 to 30 is good)
        sample_steps: 20
        # default network multiplier for all prompts
        # since we are training a slider, I recommend overriding this with --m [number]
        # in the prompts above to get both sides of the slider
        network_multiplier: 1.0

      # logging information
      logging:
        log_every: 10  # log every this many steps
        use_wandb: false  # not supported yet
        verbose: false  # probably don't need unless you are debugging

      # slider training config, best for last
      slider:
        # resolutions to train on. [ width, height ]. This is less important for sliders
        # as we are not teaching the model anything it doesn't already know
        # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
        # and [ 1024, 1024 ] for sd_xl
        # you can do as many as you want here
        resolutions:
          - [ 512, 512 ]
          # - [ 512, 768 ]
          # - [ 768, 768 ]
        # slider training uses 4 combined steps for a single round. This will do it in one gradient
        # step. It is highly optimized and shouldn't take any more vram than doing without it,
        # since we break down batches for gradient accumulation now. so just leave it on.
        batch_full_slide: true
        # These are the concepts to train on. You can do as many as you want here,
        # but they can conflict and outweigh each other. Other than experimenting, I recommend
        # just doing one for good results
        targets:
          # target_class is the base concept we are adjusting the representation of
          # for example, if we are adjusting the representation of a person, we would use "person"
          # if we are adjusting the representation of a cat, we would use "cat". It is not
          # a keyword necessarily but what the model understands the concept to represent.
          # "person" will affect men, women, children, etc but will not affect cats, dogs, etc
          # it is the model's base general understanding of the concept and everything it represents
          # you can leave it blank to affect everything. In this example, we are adjusting
          # detail, so we will leave it blank to affect everything
          - target_class: ""
            # positive is the prompt for the positive side of the slider.
            # It is the concept that will be excited and amplified in the model when we slide the slider
            # to the positive side and forgotten / inverted when we slide
            # the slider to the negative side. It is generally best to include the target_class in
            # the prompt. You want it to be the extreme of what you want to train on. For example,
            # if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
            # as the prompt. Not just "fat person"
            # max 75 tokens for now
            positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
            # negative is the prompt for the negative side of the slider and works the same as positive
            # it does not necessarily work the same as a negative prompt when generating images
            # these need to be polar opposites.
            # max 76 tokens for now
            negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
            # the loss for this target is multiplied by this number.
            # if you are doing more than one target it may be good to set less important ones
            # to a lower number like 0.1 so they don't outweigh the primary target
            weight: 1.0
            # shuffle the prompts split by the comma. We will run every combination randomly
            # this will make the LoRA more robust. You probably want this on unless prompt order
            # is important for some reason
            shuffle: true


        # anchors are prompts that we will try to hold on to while training the slider
        # these are NOT necessary and can prevent the slider from converging if not done right
        # leave them off if you are having issues, but they can help lock the network
        # on certain concepts to help prevent catastrophic forgetting
        # you want these to generate an image that is not your target_class, but close to it
        # is fine as long as it does not directly overlap it.
        # For example, if you are training on a person smiling,
        # you could use "a person with a face mask" as an anchor. It is a person, the image is the same
        # regardless if they are smiling or not, however, the closer the concept is to the target_class
        # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
        # for close concepts, you want to be closer to 0.1 or 0.2
        # these will slow down training. I am leaving them off for the demo

        # anchors:
        #   - prompt: "a woman"
        #     neg_prompt: "animal"
        #     # the multiplier applied to the LoRA when this is run.
        #     # higher will give it more weight but also help keep the lora from collapsing
        #     multiplier: 1.0
        #   - prompt: "a man"
        #     neg_prompt: "animal"
        #     multiplier: 1.0
        #   - prompt: "a person"
        #     neg_prompt: "animal"
        #     multiplier: 1.0

# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this
# plus some other information for you automatically
meta:
  # [name] gets replaced with the name above
  name: "[name]"
  # version: '1.0'
  # creator:
  #   name: Your Name
  #   email: [email protected]
  #   website: https://your.website
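The sample section above documents per-prompt flags such as --m [number] (network multiplier) and --n [string] (negative prompt). A rough sketch of how such trailing flags can be split off a prompt string (hypothetical parser with simplified rules, not the toolkit's actual one):

# prompt flag parsing sketch; the parsing rules here are simplified assumptions for illustration
import re

def parse_prompt_flags(prompt: str):
    # pull --m (network multiplier) and --n (negative prompt) flags off a sample prompt
    multiplier = 1.0
    negative = ""
    m = re.search(r"--m\s+(-?\d+(?:\.\d+)?)", prompt)
    if m:
        multiplier = float(m.group(1))
        prompt = prompt.replace(m.group(0), "")
    n = re.search(r"--n\s+(.+)$", prompt)
    if n:
        negative = n.group(1).strip()
        prompt = prompt.replace(n.group(0), "")
    return prompt.strip(), multiplier, negative

print(parse_prompt_flags("a golden retriever sitting on a leather couch --m -5"))
# -> ('a golden retriever sitting on a leather couch', -5.0, '')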
docker/Dockerfile
ADDED
@@ -0,0 +1,31 @@
FROM runpod/base:0.6.2-cuda12.2.0

LABEL authors="jaret"

# Install dependencies
RUN apt-get update

WORKDIR /app
ARG CACHEBUST=1
RUN git clone https://github.com/ostris/ai-toolkit.git && \
    cd ai-toolkit && \
    git submodule update --init --recursive

WORKDIR /app/ai-toolkit

RUN ln -s /usr/bin/python3 /usr/bin/python
RUN python -m pip install -r requirements.txt

RUN apt-get install -y tmux nvtop htop

RUN pip install jupyterlab

# make workspace
RUN mkdir /workspace


# symlink app to workspace
RUN ln -s /app/ai-toolkit /workspace/ai-toolkit

WORKDIR /
CMD ["/start.sh"]
extensions/example/ExampleMergeModels.py
ADDED
@@ -0,0 +1,129 @@
import torch
import gc
from collections import OrderedDict
from typing import TYPE_CHECKING
from jobs.process import BaseExtensionProcess
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
from tqdm import tqdm

# Type check imports. Prevents circular imports
if TYPE_CHECKING:
    from jobs import ExtensionJob


# extend standard config classes to add weight
class ModelInputConfig(ModelConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.weight = kwargs.get('weight', 1.0)
        # overwrite default dtype unless user specifies otherwise
        # float 32 will give better precision on the merging functions
        self.dtype: str = kwargs.get('dtype', 'float32')


def flush():
    torch.cuda.empty_cache()
    gc.collect()


# this is our main class process
class ExampleMergeModels(BaseExtensionProcess):
    def __init__(
            self,
            process_id: int,
            job: 'ExtensionJob',
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # this is the setup process, do not do process intensive stuff here, just variable setup and
        # checking requirements. This is called before the run() function
        # no loading models or anything like that, it is just for setting up the process
        # all of your process intensive stuff should be done in the run() function
        # config will have everything from the process item in the config file

        # convenience methods exist on BaseProcess to get config values
        # if required is set to true and the value is not found it will throw an error
        # you can pass a default value to get_conf() as well if it was not in the config file
        # as well as a type to cast the value to
        self.save_path = self.get_conf('save_path', required=True)
        self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
        self.device = self.get_conf('device', default='cpu', as_type=torch.device)

        # build models to merge list
        models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
        # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
        # this way you can add methods to it and it is easier to read and code. There are a lot of
        # inbuilt config classes located in toolkit.config_modules as well
        self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
        # setup is complete. Don't load anything else here, just setup variables and stuff

    # this is the entire run process, be sure to call super().run() first
    def run(self):
        # always call first
        super().run()
        print(f"Running process: {self.__class__.__name__}")

        # let's adjust our weights first to normalize them so the total is 1.0
        total_weight = sum([model.weight for model in self.models_to_merge])
        weight_adjust = 1.0 / total_weight
        for model in self.models_to_merge:
            model.weight *= weight_adjust

        output_model: StableDiffusion = None
        # let's do the merge, it is a good idea to use tqdm to show progress
        for model_config in tqdm(self.models_to_merge, desc="Merging models"):
            # setup model class with our helper class
            sd_model = StableDiffusion(
                device=self.device,
                model_config=model_config,
                dtype="float32"
            )
            # load the model
            sd_model.load_model()

            # adjust the weight of the text encoder
            if isinstance(sd_model.text_encoder, list):
                # sdxl model
                for text_encoder in sd_model.text_encoder:
                    for key, value in text_encoder.state_dict().items():
                        value *= model_config.weight
            else:
                # normal model
                for key, value in sd_model.text_encoder.state_dict().items():
                    value *= model_config.weight
            # adjust the weights of the unet
            for key, value in sd_model.unet.state_dict().items():
                value *= model_config.weight

            if output_model is None:
                # use this one as the base
                output_model = sd_model
            else:
                # merge the models
                # text encoder
                if isinstance(output_model.text_encoder, list):
                    # sdxl model
                    for i, text_encoder in enumerate(output_model.text_encoder):
                        for key, value in text_encoder.state_dict().items():
                            value += sd_model.text_encoder[i].state_dict()[key]
                else:
                    # normal model
                    for key, value in output_model.text_encoder.state_dict().items():
                        value += sd_model.text_encoder.state_dict()[key]
                # unet
                for key, value in output_model.unet.state_dict().items():
                    value += sd_model.unet.state_dict()[key]

                # remove the model to free memory
                del sd_model
                flush()

        # merge loop is done, let's save the model
        print(f"Saving merged model to {self.save_path}")
        output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
        print(f"Saved merged model to {self.save_path}")
        # do cleanup here
        del output_model
        flush()
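The process above merges checkpoints by scaling each model's state dict by its normalized weight and summing the results into the first model. The same idea in a compact, self-contained form (a hypothetical sketch over plain tensor dicts, not the extension's API):

# weighted state-dict merge sketch; operates on plain {name: tensor} dicts, names are illustrative
from typing import Dict, List, Tuple
import torch

def merge_state_dicts(models: List[Tuple[Dict[str, torch.Tensor], float]]) -> Dict[str, torch.Tensor]:
    # merge several state dicts as a weighted average; weights are normalized to sum to 1
    total = sum(weight for _, weight in models)
    merged: Dict[str, torch.Tensor] = {}
    for state_dict, weight in models:
        scale = weight / total
        for key, tensor in state_dict.items():
            if key in merged:
                merged[key] += tensor.float() * scale
            else:
                merged[key] = tensor.float() * scale
    return merged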
extensions/example/__init__.py
ADDED
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class ExampleMergeExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "example_merge_extension"

    # name is the name of the extension for printing
    name = "Example Merge Extension"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .ExampleMergeModels import ExampleMergeModels
        return ExampleMergeModels


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    ExampleMergeExtension
]
extensions/example/config/config.example.yaml
ADDED
@@ -0,0 +1,48 @@
---
# Always include at least one example config file to show how to use your extension.
# use plenty of comments so users know how to use it and what everything does

# all extensions will use this job name
job: extension
config:
  name: 'my_awesome_merge'
  process:
    # Put your example processes here. This will be passed
    # to your extension process in the config argument.
    # the type MUST match your extension uid
    - type: "example_merge_extension"
      # save path for the merged model
      save_path: "output/merge/[name].safetensors"
      # save type
      dtype: fp16
      # device to run it on
      device: cuda:0
      # input models can only be SD1.x and SD2.x models for this example (currently)
      models_to_merge:
        # weights are relative, total weights will be normalized
        # for example. If you have 2 models with weight 1.0, they will
        # both be weighted 0.5. If you have 1 model with weight 1.0 and
        # another with weight 2.0, the first will be weighted 1/3 and the
        # second will be weighted 2/3
        - name_or_path: "input/model1.safetensors"
          weight: 1.0
        - name_or_path: "input/model2.safetensors"
          weight: 1.0
        - name_or_path: "input/model3.safetensors"
          weight: 0.3
        - name_or_path: "input/model4.safetensors"
          weight: 1.0


# you can put any information you want here, and it will be saved in the model
# the below is an example. I recommend doing trigger words at a minimum
# in the metadata. The software will include this plus some other information
meta:
  name: "[name]"  # [name] gets replaced with the name above
  description: A short description of your model
  version: '0.1'
  creator:
    name: Your Name
    email: [email protected]
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
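To make the relative-weight comment in models_to_merge concrete, here is the normalization arithmetic applied to the four example weights above (illustrative arithmetic only, mirroring what the merge process computes):

# weight normalization example for the models_to_merge list above
weights = [1.0, 1.0, 0.3, 1.0]
total = sum(weights)                       # 3.3
normalized = [w / total for w in weights]  # ~[0.303, 0.303, 0.091, 0.303]
print(normalized)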
extensions_built_in/advanced_generator/Img2ImgGenerator.py
ADDED
@@ -0,0 +1,256 @@
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from collections import OrderedDict
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
from diffusers import T2IAdapter
|
10 |
+
from diffusers.utils.torch_utils import randn_tensor
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
16 |
+
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
17 |
+
from toolkit.sampler import get_sampler
|
18 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
19 |
+
import gc
|
20 |
+
import torch
|
21 |
+
from jobs.process import BaseExtensionProcess
|
22 |
+
from toolkit.data_loader import get_dataloader_from_datasets
|
23 |
+
from toolkit.train_tools import get_torch_dtype
|
24 |
+
from controlnet_aux.midas import MidasDetector
|
25 |
+
from diffusers.utils import load_image
|
26 |
+
from torchvision.transforms import ToTensor
|
27 |
+
|
28 |
+
|
29 |
+
def flush():
|
30 |
+
torch.cuda.empty_cache()
|
31 |
+
gc.collect()
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
class GenerateConfig:
|
38 |
+
|
39 |
+
def __init__(self, **kwargs):
|
40 |
+
self.prompts: List[str]
|
41 |
+
self.sampler = kwargs.get('sampler', 'ddpm')
|
42 |
+
self.neg = kwargs.get('neg', '')
|
43 |
+
self.seed = kwargs.get('seed', -1)
|
44 |
+
self.walk_seed = kwargs.get('walk_seed', False)
|
45 |
+
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
46 |
+
self.sample_steps = kwargs.get('sample_steps', 20)
|
47 |
+
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
48 |
+
self.ext = kwargs.get('ext', 'png')
|
49 |
+
self.denoise_strength = kwargs.get('denoise_strength', 0.5)
|
50 |
+
self.trigger_word = kwargs.get('trigger_word', None)
|
51 |
+
|
52 |
+
|
53 |
+
class Img2ImgGenerator(BaseExtensionProcess):
|
54 |
+
|
55 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
56 |
+
super().__init__(process_id, job, config)
|
57 |
+
self.output_folder = self.get_conf('output_folder', required=True)
|
58 |
+
self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
|
59 |
+
self.device = self.get_conf('device', 'cuda')
|
60 |
+
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
61 |
+
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
62 |
+
self.is_latents_cached = True
|
63 |
+
raw_datasets = self.get_conf('datasets', None)
|
64 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
65 |
+
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
|
66 |
+
self.datasets = None
|
67 |
+
self.datasets_reg = None
|
68 |
+
self.dtype = self.get_conf('dtype', 'float16')
|
69 |
+
self.torch_dtype = get_torch_dtype(self.dtype)
|
70 |
+
self.params = []
|
71 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
72 |
+
for raw_dataset in raw_datasets:
|
73 |
+
dataset = DatasetConfig(**raw_dataset)
|
74 |
+
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
|
75 |
+
if not is_caching:
|
76 |
+
self.is_latents_cached = False
|
77 |
+
if dataset.is_reg:
|
78 |
+
if self.datasets_reg is None:
|
79 |
+
self.datasets_reg = []
|
80 |
+
self.datasets_reg.append(dataset)
|
81 |
+
else:
|
82 |
+
if self.datasets is None:
|
83 |
+
self.datasets = []
|
84 |
+
self.datasets.append(dataset)
|
85 |
+
|
86 |
+
self.progress_bar = None
|
87 |
+
self.sd = StableDiffusion(
|
88 |
+
device=self.device,
|
89 |
+
model_config=self.model_config,
|
90 |
+
dtype=self.dtype,
|
91 |
+
)
|
92 |
+
print(f"Using device {self.device}")
|
93 |
+
self.data_loader: DataLoader = None
|
94 |
+
self.adapter: T2IAdapter = None
|
95 |
+
|
96 |
+
def to_pil(self, img):
|
97 |
+
# image comes in -1 to 1. convert to a PIL RGB image
|
98 |
+
img = (img + 1) / 2
|
99 |
+
img = img.clamp(0, 1)
|
100 |
+
img = img[0].permute(1, 2, 0).cpu().numpy()
|
101 |
+
img = (img * 255).astype(np.uint8)
|
102 |
+
image = Image.fromarray(img)
|
103 |
+
return image
|
104 |
+
|
105 |
+
def run(self):
|
106 |
+
with torch.no_grad():
|
107 |
+
super().run()
|
108 |
+
print("Loading model...")
|
109 |
+
self.sd.load_model()
|
110 |
+
device = torch.device(self.device)
|
111 |
+
|
112 |
+
if self.model_config.is_xl:
|
113 |
+
pipe = StableDiffusionXLImg2ImgPipeline(
|
114 |
+
vae=self.sd.vae,
|
115 |
+
unet=self.sd.unet,
|
116 |
+
text_encoder=self.sd.text_encoder[0],
|
117 |
+
text_encoder_2=self.sd.text_encoder[1],
|
118 |
+
tokenizer=self.sd.tokenizer[0],
|
119 |
+
tokenizer_2=self.sd.tokenizer[1],
|
120 |
+
scheduler=get_sampler(self.generate_config.sampler),
|
121 |
+
).to(device, dtype=self.torch_dtype)
|
122 |
+
elif self.model_config.is_pixart:
|
123 |
+
pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
|
124 |
+
else:
|
125 |
+
raise NotImplementedError("Only XL models are supported")
|
126 |
+
pipe.set_progress_bar_config(disable=True)
|
127 |
+
|
128 |
+
# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
129 |
+
+        # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
+
+        self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
+
+        num_batches = len(self.data_loader)
+        pbar = tqdm(total=num_batches, desc="Generating images")
+        seed = self.generate_config.seed
+        # load images from datasets, use tqdm
+        for i, batch in enumerate(self.data_loader):
+            batch: DataLoaderBatchDTO = batch
+
+            gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
+            generator = torch.manual_seed(gen_seed)
+
+            file_item: FileItemDTO = batch.file_items[0]
+            img_path = file_item.path
+            img_filename = os.path.basename(img_path)
+            img_filename_no_ext = os.path.splitext(img_filename)[0]
+            img_filename = img_filename_no_ext + '.' + self.generate_config.ext
+            output_path = os.path.join(self.output_folder, img_filename)
+            output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
+
+            if self.copy_inputs_to is not None:
+                output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
+                output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
+            else:
+                output_inputs_path = None
+                output_inputs_caption_path = None
+
+            caption = batch.get_caption_list()[0]
+            if self.generate_config.trigger_word is not None:
+                caption = caption.replace('[trigger]', self.generate_config.trigger_word)
+
+            img: torch.Tensor = batch.tensor.clone()
+            image = self.to_pil(img)
+
+            # image.save(output_depth_path)
+            if self.model_config.is_pixart:
+                pipe: PixArtSigmaPipeline = pipe
+
+                # Encode the full image once
+                encoded_image = pipe.vae.encode(
+                    pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
+                if hasattr(encoded_image, "latent_dist"):
+                    latents = encoded_image.latent_dist.sample(generator)
+                elif hasattr(encoded_image, "latents"):
+                    latents = encoded_image.latents
+                else:
+                    raise AttributeError("Could not access latents of provided encoder_output")
+                latents = pipe.vae.config.scaling_factor * latents
+
+                # latents = self.sd.encode_images(img)
+
+                # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
+                # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
+                # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
+                # timestep = timestep.to(device, dtype=torch.int32)
+                # latent = latent.to(device, dtype=self.torch_dtype)
+                # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
+                # latent = self.sd.add_noise(latent, noise, timestep)
+                # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
+                batch_size = 1
+                num_images_per_prompt = 1
+
+                shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
+                         image.width // pipe.vae_scale_factor)
+                noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
+
+                # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
+                num_inference_steps = self.generate_config.sample_steps
+                strength = self.generate_config.denoise_strength
+                # Get timesteps
+                init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+                t_start = max(num_inference_steps - init_timestep, 0)
+                pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
+                timesteps = pipe.scheduler.timesteps[t_start:]
+                timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+                latents = pipe.scheduler.add_noise(latents, noise, timestep)
+
+                gen_images = pipe.__call__(
+                    prompt=caption,
+                    negative_prompt=self.generate_config.neg,
+                    latents=latents,
+                    timesteps=timesteps,
+                    width=image.width,
+                    height=image.height,
+                    num_inference_steps=num_inference_steps,
+                    num_images_per_prompt=num_images_per_prompt,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    # strength=self.generate_config.denoise_strength,
+                    use_resolution_binning=False,
+                    output_type="np"
+                ).images[0]
+                gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
+                gen_images = Image.fromarray(gen_images)
+            else:
+                pipe: StableDiffusionXLImg2ImgPipeline = pipe
+
+                gen_images = pipe.__call__(
+                    prompt=caption,
+                    negative_prompt=self.generate_config.neg,
+                    image=image,
+                    num_inference_steps=self.generate_config.sample_steps,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    strength=self.generate_config.denoise_strength,
+                ).images[0]
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            gen_images.save(output_path)
+
+            # save caption
+            with open(output_caption_path, 'w') as f:
+                f.write(caption)
+
+            if output_inputs_path is not None:
+                os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
+                image.save(output_inputs_path)
+                with open(output_inputs_caption_path, 'w') as f:
+                    f.write(caption)
+
+            pbar.update(1)
+            batch.cleanup()
+
+        pbar.close()
+        print("Done generating images")
+        # cleanup
+        del self.sd
+        gc.collect()
+        torch.cuda.empty_cache()
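For reference, the strength-to-timestep mapping used in the PixArt img2img branch above works out as follows. This is a minimal sketch with assumed example values, not part of the committed file:

# sketch of the denoise_strength -> starting-timestep arithmetic used above
num_inference_steps = 20
strength = 0.5  # generate_config.denoise_strength

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 10
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
# the pipeline then noises the encoded source image to scheduler.timesteps[t_start]
# and denoises only scheduler.timesteps[t_start:], i.e. the last 10 of 20 steps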
extensions_built_in/advanced_generator/PureLoraGenerator.py
ADDED
@@ -0,0 +1,102 @@
+import os
+from collections import OrderedDict
+
+from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
+from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
+from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
+from toolkit.stable_diffusion_model import StableDiffusion
+import gc
+import torch
+from jobs.process import BaseExtensionProcess
+from toolkit.train_tools import get_torch_dtype
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class PureLoraGenerator(BaseExtensionProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        self.output_folder = self.get_conf('output_folder', required=True)
+        self.device = self.get_conf('device', 'cuda')
+        self.device_torch = torch.device(self.device)
+        self.model_config = ModelConfig(**self.get_conf('model', required=True))
+        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
+        self.dtype = self.get_conf('dtype', 'float16')
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        lorm_config = self.get_conf('lorm', None)
+        self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
+
+        self.device_state_preset = get_train_sd_device_state_preset(
+            device=torch.device(self.device),
+        )
+
+        self.progress_bar = None
+        self.sd = StableDiffusion(
+            device=self.device,
+            model_config=self.model_config,
+            dtype=self.dtype,
+        )
+
+    def run(self):
+        super().run()
+        print("Loading model...")
+        with torch.no_grad():
+            self.sd.load_model()
+            self.sd.unet.eval()
+            self.sd.unet.to(self.device_torch)
+            if isinstance(self.sd.text_encoder, list):
+                for te in self.sd.text_encoder:
+                    te.eval()
+                    te.to(self.device_torch)
+            else:
+                self.sd.text_encoder.eval()
+                self.sd.to(self.device_torch)
+
+        print(f"Converting to LoRM UNet")
+        # replace the unet with LoRMUnet
+        convert_diffusers_unet_to_lorm(
+            self.sd.unet,
+            config=self.lorm_config,
+        )
+
+        sample_folder = os.path.join(self.output_folder)
+        gen_img_config_list = []
+
+        sample_config = self.generate_config
+        start_seed = sample_config.seed
+        current_seed = start_seed
+        for i in range(len(sample_config.prompts)):
+            if sample_config.walk_seed:
+                current_seed = start_seed + i
+
+            filename = f"[time]_[count].{self.generate_config.ext}"
+            output_path = os.path.join(sample_folder, filename)
+            prompt = sample_config.prompts[i]
+            extra_args = {}
+            gen_img_config_list.append(GenerateImageConfig(
+                prompt=prompt,  # it will autoparse the prompt
+                width=sample_config.width,
+                height=sample_config.height,
+                negative_prompt=sample_config.neg,
+                seed=current_seed,
+                guidance_scale=sample_config.guidance_scale,
+                guidance_rescale=sample_config.guidance_rescale,
+                num_inference_steps=sample_config.sample_steps,
+                network_multiplier=sample_config.network_multiplier,
+                output_path=output_path,
+                output_ext=sample_config.ext,
+                adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
+                **extra_args
+            ))
+
+        # send to be generated
+        self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
+        print("Done generating images")
+        # cleanup
+        del self.sd
+        gc.collect()
+        torch.cuda.empty_cache()
extensions_built_in/advanced_generator/ReferenceGenerator.py
ADDED
@@ -0,0 +1,212 @@
+import os
+import random
+from collections import OrderedDict
+from typing import List
+
+import numpy as np
+from PIL import Image
+from diffusers import T2IAdapter
+from torch.utils.data import DataLoader
+from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
+from tqdm import tqdm
+
+from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
+from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
+from toolkit.sampler import get_sampler
+from toolkit.stable_diffusion_model import StableDiffusion
+import gc
+import torch
+from jobs.process import BaseExtensionProcess
+from toolkit.data_loader import get_dataloader_from_datasets
+from toolkit.train_tools import get_torch_dtype
+from controlnet_aux.midas import MidasDetector
+from diffusers.utils import load_image
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class GenerateConfig:
+
+    def __init__(self, **kwargs):
+        self.prompts: List[str]
+        self.sampler = kwargs.get('sampler', 'ddpm')
+        self.neg = kwargs.get('neg', '')
+        self.seed = kwargs.get('seed', -1)
+        self.walk_seed = kwargs.get('walk_seed', False)
+        self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
+        self.guidance_scale = kwargs.get('guidance_scale', 7)
+        self.sample_steps = kwargs.get('sample_steps', 20)
+        self.prompt_2 = kwargs.get('prompt_2', None)
+        self.neg_2 = kwargs.get('neg_2', None)
+        self.prompts = kwargs.get('prompts', None)
+        self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
+        self.ext = kwargs.get('ext', 'png')
+        self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
+        if kwargs.get('shuffle', False):
+            # shuffle the prompts
+            random.shuffle(self.prompts)
+
+
+class ReferenceGenerator(BaseExtensionProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        self.output_folder = self.get_conf('output_folder', required=True)
+        self.device = self.get_conf('device', 'cuda')
+        self.model_config = ModelConfig(**self.get_conf('model', required=True))
+        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
+        self.is_latents_cached = True
+        raw_datasets = self.get_conf('datasets', None)
+        if raw_datasets is not None and len(raw_datasets) > 0:
+            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
+        self.datasets = None
+        self.datasets_reg = None
+        self.dtype = self.get_conf('dtype', 'float16')
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        self.params = []
+        if raw_datasets is not None and len(raw_datasets) > 0:
+            for raw_dataset in raw_datasets:
+                dataset = DatasetConfig(**raw_dataset)
+                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
+                if not is_caching:
+                    self.is_latents_cached = False
+                if dataset.is_reg:
+                    if self.datasets_reg is None:
+                        self.datasets_reg = []
+                    self.datasets_reg.append(dataset)
+                else:
+                    if self.datasets is None:
+                        self.datasets = []
+                    self.datasets.append(dataset)
+
+        self.progress_bar = None
+        self.sd = StableDiffusion(
+            device=self.device,
+            model_config=self.model_config,
+            dtype=self.dtype,
+        )
+        print(f"Using device {self.device}")
+        self.data_loader: DataLoader = None
+        self.adapter: T2IAdapter = None
+
+    def run(self):
+        super().run()
+        print("Loading model...")
+        self.sd.load_model()
+        device = torch.device(self.device)
+
+        if self.generate_config.t2i_adapter_path is not None:
+            self.adapter = T2IAdapter.from_pretrained(
+                self.generate_config.t2i_adapter_path,
+                torch_dtype=self.torch_dtype,
+                variant="fp16"
+            ).to(device)
+
+        midas_depth = MidasDetector.from_pretrained(
+            "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+        ).to(device)
+
+        if self.model_config.is_xl:
+            pipe = StableDiffusionXLAdapterPipeline(
+                vae=self.sd.vae,
+                unet=self.sd.unet,
+                text_encoder=self.sd.text_encoder[0],
+                text_encoder_2=self.sd.text_encoder[1],
+                tokenizer=self.sd.tokenizer[0],
+                tokenizer_2=self.sd.tokenizer[1],
+                scheduler=get_sampler(self.generate_config.sampler),
+                adapter=self.adapter,
+            ).to(device, dtype=self.torch_dtype)
+        else:
+            pipe = StableDiffusionAdapterPipeline(
+                vae=self.sd.vae,
+                unet=self.sd.unet,
+                text_encoder=self.sd.text_encoder,
+                tokenizer=self.sd.tokenizer,
+                scheduler=get_sampler(self.generate_config.sampler),
+                safety_checker=None,
+                feature_extractor=None,
+                requires_safety_checker=False,
+                adapter=self.adapter,
+            ).to(device, dtype=self.torch_dtype)
+        pipe.set_progress_bar_config(disable=True)
+
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+        # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
+
+        self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
+
+        num_batches = len(self.data_loader)
+        pbar = tqdm(total=num_batches, desc="Generating images")
+        seed = self.generate_config.seed
+        # load images from datasets, use tqdm
+        for i, batch in enumerate(self.data_loader):
+            batch: DataLoaderBatchDTO = batch
+
+            file_item: FileItemDTO = batch.file_items[0]
+            img_path = file_item.path
+            img_filename = os.path.basename(img_path)
+            img_filename_no_ext = os.path.splitext(img_filename)[0]
+            output_path = os.path.join(self.output_folder, img_filename)
+            output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
+            output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
+
+            caption = batch.get_caption_list()[0]
+
+            img: torch.Tensor = batch.tensor.clone()
+            # image comes in -1 to 1. convert to a PIL RGB image
+            img = (img + 1) / 2
+            img = img.clamp(0, 1)
+            img = img[0].permute(1, 2, 0).cpu().numpy()
+            img = (img * 255).astype(np.uint8)
+            image = Image.fromarray(img)
+
+            width, height = image.size
+            min_res = min(width, height)
+
+            if self.generate_config.walk_seed:
+                seed = seed + 1
+
+            if self.generate_config.seed == -1:
+                # random
+                seed = random.randint(0, 1000000)
+
+            torch.manual_seed(seed)
+            torch.cuda.manual_seed(seed)
+
+            # generate depth map
+            image = midas_depth(
+                image,
+                detect_resolution=min_res,  # do 512 ?
+                image_resolution=min_res
+            )
+
+            # image.save(output_depth_path)
+
+            gen_images = pipe(
+                prompt=caption,
+                negative_prompt=self.generate_config.neg,
+                image=image,
+                num_inference_steps=self.generate_config.sample_steps,
+                adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
+                guidance_scale=self.generate_config.guidance_scale,
+            ).images[0]
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            gen_images.save(output_path)
+
+            # save caption
+            with open(output_caption_path, 'w') as f:
+                f.write(caption)
+
+            pbar.update(1)
+            batch.cleanup()
+
+        pbar.close()
+        print("Done generating images")
+        # cleanup
+        del self.sd
+        gc.collect()
+        torch.cuda.empty_cache()
extensions_built_in/advanced_generator/__init__.py
ADDED
@@ -0,0 +1,59 @@
+# This is an example extension for custom training. It is great for experimenting with new ideas.
+from toolkit.extension import Extension
+
+
+# This is for generic training (LoRA, Dreambooth, FineTuning)
+class AdvancedReferenceGeneratorExtension(Extension):
+    # uid must be unique, it is how the extension is identified
+    uid = "reference_generator"
+
+    # name is the name of the extension for printing
+    name = "Reference Generator"
+
+    # This is where your process class is loaded
+    # keep your imports in here so they don't slow down the rest of the program
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .ReferenceGenerator import ReferenceGenerator
+        return ReferenceGenerator
+
+
+# This is for generic training (LoRA, Dreambooth, FineTuning)
+class PureLoraGenerator(Extension):
+    # uid must be unique, it is how the extension is identified
+    uid = "pure_lora_generator"
+
+    # name is the name of the extension for printing
+    name = "Pure LoRA Generator"
+
+    # This is where your process class is loaded
+    # keep your imports in here so they don't slow down the rest of the program
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .PureLoraGenerator import PureLoraGenerator
+        return PureLoraGenerator
+
+
+# This is for generic training (LoRA, Dreambooth, FineTuning)
+class Img2ImgGeneratorExtension(Extension):
+    # uid must be unique, it is how the extension is identified
+    uid = "batch_img2img"
+
+    # name is the name of the extension for printing
+    name = "Img2ImgGeneratorExtension"
+
+    # This is where your process class is loaded
+    # keep your imports in here so they don't slow down the rest of the program
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .Img2ImgGenerator import Img2ImgGenerator
+        return Img2ImgGenerator
+
+
+AI_TOOLKIT_EXTENSIONS = [
+    # you can put a list of extensions here
+    AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
+]
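For context, a custom extension follows the same pattern as the classes above: subclass Extension, give it a unique uid and a printable name, import the process class lazily inside get_process, and export it through AI_TOOLKIT_EXTENSIONS. A minimal sketch, where MyGeneratorExtension and MyGenerator are hypothetical names not part of this commit:

# minimal sketch of a custom extension module (hypothetical names)
from toolkit.extension import Extension


class MyGeneratorExtension(Extension):
    # uid must be unique across installed extensions
    uid = "my_generator"
    # printable name
    name = "My Generator"

    @classmethod
    def get_process(cls):
        # heavy imports stay inside get_process so they only load when the job runs
        from .MyGenerator import MyGenerator  # hypothetical process class
        return MyGenerator


AI_TOOLKIT_EXTENSIONS = [
    MyGeneratorExtension,
]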
extensions_built_in/advanced_generator/config/train.example.yaml
ADDED
@@ -0,0 +1,91 @@
+---
+job: extension
+config:
+  name: test_v1
+  process:
+    - type: 'textual_inversion_trainer'
+      training_folder: "out/TI"
+      device: cuda:0
+      # for tensorboard logging
+      log_dir: "out/.tensorboard"
+      embedding:
+        trigger: "your_trigger_here"
+        tokens: 12
+        init_words: "man with short brown hair"
+        save_format: "safetensors" # 'safetensors' or 'pt'
+      save:
+        dtype: float16 # precision to save
+        save_every: 100 # save every this many steps
+        max_step_saves_to_keep: 5 # only affects step counts
+      datasets:
+        - folder_path: "/path/to/dataset"
+          caption_ext: "txt"
+          default_caption: "[trigger]"
+          buckets: true
+          resolution: 512
+      train:
+        noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
+        steps: 3000
+        weight_jitter: 0.0
+        lr: 5e-5
+        train_unet: false
+        gradient_checkpointing: true
+        train_text_encoder: false
+        optimizer: "adamw"
+        # optimizer: "prodigy"
+        optimizer_params:
+          weight_decay: 1e-2
+        lr_scheduler: "constant"
+        max_denoising_steps: 1000
+        batch_size: 4
+        dtype: bf16
+        xformers: true
+        min_snr_gamma: 5.0
+        # skip_first_sample: true
+        noise_offset: 0.0 # not needed for this
+      model:
+        # objective reality v2
+        name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
+        is_v2: false # for v2 models
+        is_xl: false # for SDXL models
+        is_v_pred: false # for v-prediction models (most v2 models)
+      sample:
+        sampler: "ddpm" # must match train.noise_scheduler
+        sample_every: 100 # sample every this many steps
+        width: 512
+        height: 512
+        prompts:
+          - "photo of [trigger] laughing"
+          - "photo of [trigger] smiling"
+          - "[trigger] close up"
+          - "dark scene [trigger] frozen"
+          - "[trigger] nighttime"
+          - "a painting of [trigger]"
+          - "a drawing of [trigger]"
+          - "a cartoon of [trigger]"
+          - "[trigger] pixar style"
+          - "[trigger] costume"
+        neg: ""
+        seed: 42
+        walk_seed: false
+        guidance_scale: 7
+        sample_steps: 20
+        network_multiplier: 1.0
+
+      logging:
+        log_every: 10 # log every this many steps
+        use_wandb: false # not supported yet
+        verbose: false
+
+# You can put any information you want here, and it will be saved in the model.
+# The below is an example, but you can put your grocery list in it if you want.
+# It is saved in the model so be aware of that. The software will include this
+# plus some other information for you automatically
+meta:
+  # [name] gets replaced with the name above
+  name: "[name]"
+  # version: '1.0'
+  # creator:
+  #   name: Your Name
+  #   email: [email protected]
+  #   website: https://your.website
extensions_built_in/concept_replacer/ConceptReplacer.py
ADDED
@@ -0,0 +1,151 @@
+import random
+from collections import OrderedDict
+from torch.utils.data import DataLoader
+from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
+from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight
+import gc
+import torch
+from jobs.process import BaseSDTrainProcess
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class ConceptReplacementConfig:
+    def __init__(self, **kwargs):
+        self.concept: str = kwargs.get('concept', '')
+        self.replacement: str = kwargs.get('replacement', '')
+
+
+class ConceptReplacer(BaseSDTrainProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
+        super().__init__(process_id, job, config, **kwargs)
+        replacement_list = self.config.get('replacements', [])
+        self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list]
+
+    def before_model_load(self):
+        pass
+
+    def hook_before_train_loop(self):
+        self.sd.vae.eval()
+        self.sd.vae.to(self.device_torch)
+
+        # textual inversion
+        if self.embedding is not None:
+            # set text encoder to train. Not sure if this is necessary but diffusers example did it
+            self.sd.text_encoder.train()
+
+    def hook_train_loop(self, batch):
+        with torch.no_grad():
+            dtype = get_torch_dtype(self.train_config.dtype)
+            noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
+            network_weight_list = batch.get_network_weight_list()
+
+        # have a blank network so we can wrap it in a context and set multipliers without checking every time
+        if self.network is not None:
+            network = self.network
+        else:
+            network = BlankNetwork()
+
+        batch_replacement_list = []
+        # get a random replacement for each prompt
+        for prompt in conditioned_prompts:
+            replacement = random.choice(self.replacement_list)
+            batch_replacement_list.append(replacement)
+
+        # build out prompts
+        concept_prompts = []
+        replacement_prompts = []
+        for idx, replacement in enumerate(batch_replacement_list):
+            prompt = conditioned_prompts[idx]
+
+            # insert shuffled concept at beginning and end of prompt
+            shuffled_concept = [x.strip() for x in replacement.concept.split(',')]
+            random.shuffle(shuffled_concept)
+            shuffled_concept = ', '.join(shuffled_concept)
+            concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}")
+
+            # insert replacement at beginning and end of prompt
+            shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')]
+            random.shuffle(shuffled_replacement)
+            shuffled_replacement = ', '.join(shuffled_replacement)
+            replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}")
+
+        # predict the replacement without network
+        conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype)
+
+        replacement_pred = self.sd.predict_noise(
+            latents=noisy_latents.to(self.device_torch, dtype=dtype),
+            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+            timestep=timesteps,
+            guidance_scale=1.0,
+        )
+
+        del conditional_embeds
+        replacement_pred = replacement_pred.detach()
+
+        self.optimizer.zero_grad()
+        flush()
+
+        # text encoding
+        grad_on_text_encoder = False
+        if self.train_config.train_text_encoder:
+            grad_on_text_encoder = True
+
+        if self.embedding:
+            grad_on_text_encoder = True
+
+        # set the weights
+        network.multiplier = network_weight_list
+
+        # activate network if it exits
+        with network:
+            with torch.set_grad_enabled(grad_on_text_encoder):
+                # embed the prompts
+                conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype)
+                if not grad_on_text_encoder:
+                    # detach the embeddings
+                    conditional_embeds = conditional_embeds.detach()
+                    self.optimizer.zero_grad()
+                    flush()
+
+            noise_pred = self.sd.predict_noise(
+                latents=noisy_latents.to(self.device_torch, dtype=dtype),
+                conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+                timestep=timesteps,
+                guidance_scale=1.0,
+            )
+
+            loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
+            loss = loss.mean([1, 2, 3])
+
+            if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+                # add min_snr_gamma
+                loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
+
+            loss = loss.mean()
+
+            # back propagate loss to free ram
+            loss.backward()
+            flush()
+
+        # apply gradients
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+        self.lr_scheduler.step()
+
+        if self.embedding is not None:
+            # Let's make sure we don't update any embedding weights besides the newly added token
+            self.embedding.restore_embeddings()
+
+        loss_dict = OrderedDict(
+            {'loss': loss.item()}
+        )
+        # reset network multiplier
+        network.multiplier = 1.0
+
+        return loss_dict
extensions_built_in/concept_replacer/__init__.py
ADDED
@@ -0,0 +1,26 @@
+# This is an example extension for custom training. It is great for experimenting with new ideas.
+from toolkit.extension import Extension
+
+
+# This is for generic training (LoRA, Dreambooth, FineTuning)
+class ConceptReplacerExtension(Extension):
+    # uid must be unique, it is how the extension is identified
+    uid = "concept_replacer"
+
+    # name is the name of the extension for printing
+    name = "Concept Replacer"
+
+    # This is where your process class is loaded
+    # keep your imports in here so they don't slow down the rest of the program
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .ConceptReplacer import ConceptReplacer
+        return ConceptReplacer
+
+
+
+AI_TOOLKIT_EXTENSIONS = [
+    # you can put a list of extensions here
+    ConceptReplacerExtension,
+]
extensions_built_in/concept_replacer/config/train.example.yaml
ADDED
@@ -0,0 +1,91 @@
+---
+job: extension
+config:
+  name: test_v1
+  process:
+    - type: 'textual_inversion_trainer'
+      training_folder: "out/TI"
+      device: cuda:0
+      # for tensorboard logging
+      log_dir: "out/.tensorboard"
+      embedding:
+        trigger: "your_trigger_here"
+        tokens: 12
+        init_words: "man with short brown hair"
+        save_format: "safetensors" # 'safetensors' or 'pt'
+      save:
+        dtype: float16 # precision to save
+        save_every: 100 # save every this many steps
+        max_step_saves_to_keep: 5 # only affects step counts
+      datasets:
+        - folder_path: "/path/to/dataset"
+          caption_ext: "txt"
+          default_caption: "[trigger]"
+          buckets: true
+          resolution: 512
+      train:
+        noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
+        steps: 3000
+        weight_jitter: 0.0
+        lr: 5e-5
+        train_unet: false
+        gradient_checkpointing: true
+        train_text_encoder: false
+        optimizer: "adamw"
+        # optimizer: "prodigy"
+        optimizer_params:
+          weight_decay: 1e-2
+        lr_scheduler: "constant"
+        max_denoising_steps: 1000
+        batch_size: 4
+        dtype: bf16
+        xformers: true
+        min_snr_gamma: 5.0
+        # skip_first_sample: true
+        noise_offset: 0.0 # not needed for this
+      model:
+        # objective reality v2
+        name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
+        is_v2: false # for v2 models
+        is_xl: false # for SDXL models
+        is_v_pred: false # for v-prediction models (most v2 models)
+      sample:
+        sampler: "ddpm" # must match train.noise_scheduler
+        sample_every: 100 # sample every this many steps
+        width: 512
+        height: 512
+        prompts:
+          - "photo of [trigger] laughing"
+          - "photo of [trigger] smiling"
+          - "[trigger] close up"
+          - "dark scene [trigger] frozen"
+          - "[trigger] nighttime"
+          - "a painting of [trigger]"
+          - "a drawing of [trigger]"
+          - "a cartoon of [trigger]"
+          - "[trigger] pixar style"
+          - "[trigger] costume"
+        neg: ""
+        seed: 42
+        walk_seed: false
+        guidance_scale: 7
+        sample_steps: 20
+        network_multiplier: 1.0
+
+      logging:
+        log_every: 10 # log every this many steps
+        use_wandb: false # not supported yet
+        verbose: false
+
+# You can put any information you want here, and it will be saved in the model.
+# The below is an example, but you can put your grocery list in it if you want.
+# It is saved in the model so be aware of that. The software will include this
+# plus some other information for you automatically
+meta:
+  # [name] gets replaced with the name above
+  name: "[name]"
+  # version: '1.0'
+  # creator:
+  #   name: Your Name
+  #   email: [email protected]
+  #   website: https://your.website
extensions_built_in/dataset_tools/DatasetTools.py
ADDED
@@ -0,0 +1,20 @@
+from collections import OrderedDict
+import gc
+import torch
+from jobs.process import BaseExtensionProcess
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class DatasetTools(BaseExtensionProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+
+    def run(self):
+        super().run()
+
+        raise NotImplementedError("This extension is not yet implemented")
extensions_built_in/dataset_tools/SuperTagger.py
ADDED
@@ -0,0 +1,196 @@
+import copy
+import json
+import os
+from collections import OrderedDict
+import gc
+import traceback
+import torch
+from PIL import Image, ImageOps
+from tqdm import tqdm
+
+from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
+from .tools.fuyu_utils import FuyuImageProcessor
+from .tools.image_tools import load_image, ImageProcessor, resize_to_max
+from .tools.llava_utils import LLaVAImageProcessor
+from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
+from jobs.process import BaseExtensionProcess
+from .tools.sync_tools import get_img_paths
+
+img_ext = ['.jpg', '.jpeg', '.png', '.webp']
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+VERSION = 2
+
+
+class SuperTagger(BaseExtensionProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        parent_dir = config.get('parent_dir', None)
+        self.dataset_paths: list[str] = config.get('dataset_paths', [])
+        self.device = config.get('device', 'cuda')
+        self.steps: list[Step] = config.get('steps', [])
+        self.caption_method = config.get('caption_method', 'llava:default')
+        self.caption_prompt = config.get('caption_prompt', default_long_prompt)
+        self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
+        self.force_reprocess_img = config.get('force_reprocess_img', False)
+        self.caption_replacements = config.get('caption_replacements', default_replacements)
+        self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
+        self.master_dataset_dict = OrderedDict()
+        self.dataset_master_config_file = config.get('dataset_master_config_file', None)
+        if parent_dir is not None and len(self.dataset_paths) == 0:
+            # find all folders in the parent dataset path
+            self.dataset_paths = [
+                os.path.join(parent_dir, folder)
+                for folder in os.listdir(parent_dir)
+                if os.path.isdir(os.path.join(parent_dir, folder))
+            ]
+        else:
+            # make sure they exist
+            for dataset_path in self.dataset_paths:
+                if not os.path.exists(dataset_path):
+                    raise ValueError(f"Dataset path does not exist: {dataset_path}")
+
+        print(f"Found {len(self.dataset_paths)} dataset paths")
+
+        self.image_processor: ImageProcessor = self.get_image_processor()
+
+    def get_image_processor(self):
+        if self.caption_method.startswith('llava'):
+            return LLaVAImageProcessor(device=self.device)
+        elif self.caption_method.startswith('fuyu'):
+            return FuyuImageProcessor(device=self.device)
+        else:
+            raise ValueError(f"Unknown caption method: {self.caption_method}")
+
+    def process_image(self, img_path: str):
+        root_img_dir = os.path.dirname(os.path.dirname(img_path))
+        filename = os.path.basename(img_path)
+        filename_no_ext = os.path.splitext(filename)[0]
+        train_dir = os.path.join(root_img_dir, TRAIN_DIR)
+        train_img_path = os.path.join(train_dir, filename)
+        json_path = os.path.join(train_dir, f"{filename_no_ext}.json")
+
+        # check if json exists, if it does load it as image info
+        if os.path.exists(json_path):
+            with open(json_path, 'r') as f:
+                img_info = ImgInfo(**json.load(f))
+        else:
+            img_info = ImgInfo()
+
+        # always send steps first in case other processes need them
+        img_info.add_steps(copy.deepcopy(self.steps))
+        img_info.set_version(VERSION)
+        img_info.set_caption_method(self.caption_method)
+
+        image: Image = None
+        caption_image: Image = None
+
+        did_update_image = False
+
+        # trigger reprocess of steps
+        if self.force_reprocess_img:
+            img_info.trigger_image_reprocess()
+
+        # set the image as updated if it does not exist on disk
+        if not os.path.exists(train_img_path):
+            did_update_image = True
+            image = load_image(img_path)
+        if img_info.force_image_process:
+            did_update_image = True
+            image = load_image(img_path)
+
+        # go through the needed steps
+        for step in copy.deepcopy(img_info.state.steps_to_complete):
+            if step == 'caption':
+                # load image
+                if image is None:
+                    image = load_image(img_path)
+                if caption_image is None:
+                    caption_image = resize_to_max(image, 1024, 1024)
+
+                if not self.image_processor.is_loaded:
+                    print('Loading Model. Takes a while, especially the first time')
+                    self.image_processor.load_model()
+
+                img_info.caption = self.image_processor.generate_caption(
+                    image=caption_image,
+                    prompt=self.caption_prompt,
+                    replacements=self.caption_replacements
+                )
+                img_info.mark_step_complete(step)
+            elif step == 'caption_short':
+                # load image
+                if image is None:
+                    image = load_image(img_path)
+
+                if caption_image is None:
+                    caption_image = resize_to_max(image, 1024, 1024)
+
+                if not self.image_processor.is_loaded:
+                    print('Loading Model. Takes a while, especially the first time')
+                    self.image_processor.load_model()
+                img_info.caption_short = self.image_processor.generate_caption(
+                    image=caption_image,
+                    prompt=self.caption_short_prompt,
+                    replacements=self.caption_short_replacements
+                )
+                img_info.mark_step_complete(step)
+            elif step == 'contrast_stretch':
+                # load image
+                if image is None:
+                    image = load_image(img_path)
+                image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
+                did_update_image = True
+                img_info.mark_step_complete(step)
+            else:
+                raise ValueError(f"Unknown step: {step}")
+
+        os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
+        if did_update_image:
+            image.save(train_img_path)
+
+        if img_info.is_dirty:
+            with open(json_path, 'w') as f:
+                json.dump(img_info.to_dict(), f, indent=4)
+
+        if self.dataset_master_config_file:
+            # add to master dict
+            self.master_dataset_dict[train_img_path] = img_info.to_dict()
+
+    def run(self):
+        super().run()
+        imgs_to_process = []
+        # find all images
+        for dataset_path in self.dataset_paths:
+            raw_dir = os.path.join(dataset_path, RAW_DIR)
+            raw_image_paths = get_img_paths(raw_dir)
+            for raw_image_path in raw_image_paths:
+                imgs_to_process.append(raw_image_path)
+
+        if len(imgs_to_process) == 0:
+            print(f"No images to process")
+        else:
+            print(f"Found {len(imgs_to_process)} to process")
+
+        for img_path in tqdm(imgs_to_process, desc="Processing images"):
+            try:
+                self.process_image(img_path)
+            except Exception:
+                # print full stack trace
+                print(traceback.format_exc())
+                continue
+            # self.process_image(img_path)
+
+        if self.dataset_master_config_file is not None:
+            # save it as json
+            with open(self.dataset_master_config_file, 'w') as f:
+                json.dump(self.master_dataset_dict, f, indent=4)
+
+        del self.image_processor
+        flush()
extensions_built_in/dataset_tools/SyncFromCollection.py
ADDED
@@ -0,0 +1,131 @@
+import os
+import shutil
+from collections import OrderedDict
+import gc
+from typing import List
+
+import torch
+from tqdm import tqdm
+
+from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
+from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
+    get_img_paths
+from jobs.process import BaseExtensionProcess
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class SyncFromCollection(BaseExtensionProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+
+        self.min_width = config.get('min_width', 1024)
+        self.min_height = config.get('min_height', 1024)
+
+        # add our min_width and min_height to each dataset config if they don't exist
+        for dataset_config in config.get('dataset_sync', []):
+            if 'min_width' not in dataset_config:
+                dataset_config['min_width'] = self.min_width
+            if 'min_height' not in dataset_config:
+                dataset_config['min_height'] = self.min_height
+
+        self.dataset_configs: List[DatasetSyncCollectionConfig] = [
+            DatasetSyncCollectionConfig(**dataset_config)
+            for dataset_config in config.get('dataset_sync', [])
+        ]
+        print(f"Found {len(self.dataset_configs)} dataset configs")
+
+    def move_new_images(self, root_dir: str):
+        raw_dir = os.path.join(root_dir, RAW_DIR)
+        new_dir = os.path.join(root_dir, NEW_DIR)
+        new_images = get_img_paths(new_dir)
+
+        for img_path in new_images:
+            # move to raw
+            new_path = os.path.join(raw_dir, os.path.basename(img_path))
+            shutil.move(img_path, new_path)
+
+        # remove new dir
+        shutil.rmtree(new_dir)
+
+    def sync_dataset(self, config: DatasetSyncCollectionConfig):
+        if config.host == 'unsplash':
+            get_images = get_unsplash_images
+        elif config.host == 'pexels':
+            get_images = get_pexels_images
+        else:
+            raise ValueError(f"Unknown host: {config.host}")
+
+        results = {
+            'num_downloaded': 0,
+            'num_skipped': 0,
+            'bad': 0,
+            'total': 0,
+        }
+
+        photos = get_images(config)
+        raw_dir = os.path.join(config.directory, RAW_DIR)
+        new_dir = os.path.join(config.directory, NEW_DIR)
+        raw_images = get_local_image_file_names(raw_dir)
+        new_images = get_local_image_file_names(new_dir)
+
+        for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
+            try:
+                if photo.filename not in raw_images and photo.filename not in new_images:
+                    download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
+                    results['num_downloaded'] += 1
+                else:
+                    results['num_skipped'] += 1
+            except Exception as e:
+                print(f" - BAD({photo.id}): {e}")
+                results['bad'] += 1
+                continue
+            results['total'] += 1
+
+        return results
+
+    def print_results(self, results):
+        print(
+            f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")
+
+    def run(self):
+        super().run()
+        print(f"Syncing {len(self.dataset_configs)} datasets")
+        all_results = None
+        failed_datasets = []
+        for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
+            try:
+                results = self.sync_dataset(dataset_config)
+                if all_results is None:
+                    all_results = {**results}
+                else:
+                    for key, value in results.items():
+                        all_results[key] += value
+
+                self.print_results(results)
+            except Exception as e:
+                print(f" - FAILED: {e}")
+                if 'response' in e.__dict__:
+                    error = f"{e.response.status_code}: {e.response.text}"
+                    print(f" - {error}")
+                    failed_datasets.append({'dataset': dataset_config, 'error': error})
+                else:
+                    failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
+                continue
+
+        print("Moving new images to raw")
+        for dataset_config in self.dataset_configs:
+            self.move_new_images(dataset_config.directory)
+
+        print("Done syncing datasets")
+        self.print_results(all_results)
+
+        if len(failed_datasets) > 0:
+            print(f"Failed to sync {len(failed_datasets)} datasets")
+            for failed in failed_datasets:
+                print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
+                print(f"    - ERR: {failed['error']}")
extensions_built_in/dataset_tools/__init__.py
ADDED
@@ -0,0 +1,43 @@
+from toolkit.extension import Extension
+
+
+class DatasetToolsExtension(Extension):
+    uid = "dataset_tools"
+
+    # name is the name of the extension for printing
+    name = "Dataset Tools"
+
+    # This is where your process class is loaded
+    # keep your imports in here so they don't slow down the rest of the program
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .DatasetTools import DatasetTools
+        return DatasetTools
+
+
+class SyncFromCollectionExtension(Extension):
+    uid = "sync_from_collection"
+    name = "Sync from Collection"
+
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .SyncFromCollection import SyncFromCollection
+        return SyncFromCollection
+
+
+class SuperTaggerExtension(Extension):
+    uid = "super_tagger"
+    name = "Super Tagger"
+
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .SuperTagger import SuperTagger
+        return SuperTagger
+
+
+AI_TOOLKIT_EXTENSIONS = [
+    SyncFromCollectionExtension, DatasetToolsExtension, SuperTaggerExtension
+]
extensions_built_in/dataset_tools/tools/caption.py
ADDED
@@ -0,0 +1,53 @@
+
+caption_manipulation_steps = ['caption', 'caption_short']
+
+default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
+default_short_prompt = 'caption this image in less than ten words'
+
+default_replacements = [
+    ("the image features", ""),
+    ("the image shows", ""),
+    ("the image depicts", ""),
+    ("the image is", ""),
+    ("in this image", ""),
+    ("in the image", ""),
+]
+
+
+def clean_caption(cap, replacements=None):
+    if replacements is None:
+        replacements = default_replacements
+
+    # remove any newlines
+    cap = cap.replace("\n", ", ")
+    cap = cap.replace("\r", ", ")
+    cap = cap.replace(".", ",")
+    cap = cap.replace("\"", "")
+
+    # remove unicode characters
+    cap = cap.encode('ascii', 'ignore').decode('ascii')
+
+    # make lowercase
+    cap = cap.lower()
+    # remove any extra spaces
+    cap = " ".join(cap.split())
+
+    for replacement in replacements:
+        if replacement[0].startswith('*'):
+            # we are removing all text if it starts with this and the rest matches
+            search_text = replacement[0][1:]
+            if cap.startswith(search_text):
+                cap = ""
+        else:
+            cap = cap.replace(replacement[0].lower(), replacement[1].lower())
+
+    cap_list = cap.split(",")
+    # trim whitespace
+    cap_list = [c.strip() for c in cap_list]
+    # remove empty strings
+    cap_list = [c for c in cap_list if c != ""]
+    # remove duplicates
+    cap_list = list(dict.fromkeys(cap_list))
+    # join back together
+    cap = ", ".join(cap_list)
+    return cap
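To illustrate what clean_caption does end to end (strip boilerplate lead-ins, lowercase, split on commas, deduplicate), here is a small usage sketch; the example string is made up and the import path simply mirrors this file's location in the repo:

# usage sketch for clean_caption
from extensions_built_in.dataset_tools.tools.caption import clean_caption

raw = 'The image shows a man. He is smiling, he is smiling.\n"Warm lighting."'
print(clean_caption(raw))
# prints: a man, he is smiling, warm lighting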
extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py
ADDED
@@ -0,0 +1,187 @@
import json
from typing import Literal, Type, TYPE_CHECKING

Host: Type = Literal['unsplash', 'pexels']

RAW_DIR = "raw"
NEW_DIR = "_tmp"
TRAIN_DIR = "train"
DEPTH_DIR = "depth"

from .image_tools import Step, img_manipulation_steps
from .caption import caption_manipulation_steps


class DatasetSyncCollectionConfig:
    def __init__(self, **kwargs):
        self.host: Host = kwargs.get('host', None)
        self.collection_id: str = kwargs.get('collection_id', None)
        self.directory: str = kwargs.get('directory', None)
        self.api_key: str = kwargs.get('api_key', None)
        self.min_width: int = kwargs.get('min_width', 1024)
        self.min_height: int = kwargs.get('min_height', 1024)

        if self.host is None:
            raise ValueError("host is required")
        if self.collection_id is None:
            raise ValueError("collection_id is required")
        if self.directory is None:
            raise ValueError("directory is required")
        if self.api_key is None:
            raise ValueError(f"api_key is required: {self.host}:{self.collection_id}")


class ImageState:
    def __init__(self, **kwargs):
        self.steps_complete: list[Step] = kwargs.get('steps_complete', [])
        self.steps_to_complete: list[Step] = kwargs.get('steps_to_complete', [])

    def to_dict(self):
        return {
            'steps_complete': self.steps_complete
        }


class Rect:
    def __init__(self, **kwargs):
        self.x = kwargs.get('x', 0)
        self.y = kwargs.get('y', 0)
        self.width = kwargs.get('width', 0)
        self.height = kwargs.get('height', 0)

    def to_dict(self):
        return {
            'x': self.x,
            'y': self.y,
            'width': self.width,
            'height': self.height
        }


class ImgInfo:
    def __init__(self, **kwargs):
        self.version: int = kwargs.get('version', None)
        self.caption: str = kwargs.get('caption', None)
        self.caption_short: str = kwargs.get('caption_short', None)
        self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
        self.state = ImageState(**kwargs.get('state', {}))
        self.caption_method = kwargs.get('caption_method', None)
        self.other_captions = kwargs.get('other_captions', {})
        self._upgrade_state()
        self.force_image_process: bool = False
        self._requested_steps: list[Step] = []

        self.is_dirty: bool = False

    def _upgrade_state(self):
        # upgrades older states
        if self.caption is not None and 'caption' not in self.state.steps_complete:
            self.mark_step_complete('caption')
            self.is_dirty = True
        if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
            self.mark_step_complete('caption_short')
            self.is_dirty = True
        if self.caption_method is None and self.caption is not None:
            # added caption method in version 2. Was all llava before that
            self.caption_method = 'llava:default'
            self.is_dirty = True

    def to_dict(self):
        return {
            'version': self.version,
            'caption_method': self.caption_method,
            'caption': self.caption,
            'caption_short': self.caption_short,
            'poi': [poi.to_dict() for poi in self.poi],
            'state': self.state.to_dict(),
            'other_captions': self.other_captions
        }

    def mark_step_complete(self, step: Step):
        if step not in self.state.steps_complete:
            self.state.steps_complete.append(step)
        if step in self.state.steps_to_complete:
            self.state.steps_to_complete.remove(step)
        self.is_dirty = True

    def add_step(self, step: Step):
        if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
            self.state.steps_to_complete.append(step)

    def trigger_image_reprocess(self):
        if self._requested_steps is None:
            raise Exception("Must call add_steps before trigger_image_reprocess")
        steps = self._requested_steps
        # remove all image manipulations from steps_to_complete
        for step in img_manipulation_steps:
            if step in self.state.steps_to_complete:
                self.state.steps_to_complete.remove(step)
            if step in self.state.steps_complete:
                self.state.steps_complete.remove(step)
        self.force_image_process = True
        self.is_dirty = True
        # we want to keep the order passed in process file
        for step in steps:
            if step in img_manipulation_steps:
                self.add_step(step)

    def add_steps(self, steps: list[Step]):
        self._requested_steps = [step for step in steps]
        for stage in steps:
            self.add_step(stage)

        # update steps if we have any img processes not complete, we have to reprocess them all
        # if any steps_to_complete are in img_manipulation_steps

        is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete])
        order_has_changed = False

        if not is_manipulating_image:
            # check to see if order has changed. No need to if already redoing it. Will detect if ones are removed
            target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps]
            current_img_manipulation_order = [step for step in self.state.steps_complete if
                                              step in img_manipulation_steps]
            if target_img_manipulation_order != current_img_manipulation_order:
                order_has_changed = True

        if is_manipulating_image or order_has_changed:
            self.trigger_image_reprocess()

    def set_caption_method(self, method: str):
        if self._requested_steps is None:
            raise Exception("Must call add_steps before set_caption_method")
        if self.caption_method != method:
            self.is_dirty = True
            # move previous caption method to other_captions
            if self.caption_method is not None and self.caption is not None or self.caption_short is not None:
                self.other_captions[self.caption_method] = {
                    'caption': self.caption,
                    'caption_short': self.caption_short,
                }
            self.caption_method = method
            self.caption = None
            self.caption_short = None
            # see if we have a caption from the new method
            if method in self.other_captions:
                self.caption = self.other_captions[method].get('caption', None)
                self.caption_short = self.other_captions[method].get('caption_short', None)
            else:
                self.trigger_new_caption()

    def trigger_new_caption(self):
        self.caption = None
        self.caption_short = None
        self.is_dirty = True
        # check to see if we have any steps in the complete list and move them to the to_complete list
        for step in self.state.steps_complete:
            if step in caption_manipulation_steps:
                self.state.steps_complete.remove(step)
                self.state.steps_to_complete.append(step)

    def to_json(self):
        return json.dumps(self.to_dict())

    def set_version(self, version: int):
        if self.version != version:
            self.is_dirty = True
        self.version = version
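
Not part of the committed file: a minimal sketch of how ImgInfo is meant to be driven. add_steps() must be called before trigger_image_reprocess() or set_caption_method(); the caption-method string and the output path below are illustrative assumptions, not values defined in this module.

info = ImgInfo(version=2, caption="a dog on a beach", state={'steps_complete': ['caption']})
info.add_steps(['caption', 'caption_short', 'contrast_stretch'])
info.set_caption_method('fuyu:default')  # assumed method name; old captions move to other_captions
if info.is_dirty:
    with open('image.json', 'w') as f:   # hypothetical sidecar path
        f.write(info.to_json())
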
extensions_built_in/dataset_tools/tools/fuyu_utils.py
ADDED
@@ -0,0 +1,66 @@
from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer

from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
import torch
from PIL import Image


class FuyuImageProcessor:
    def __init__(self, device='cuda'):
        from transformers import FuyuProcessor, FuyuForCausalLM
        self.device = device
        self.model: FuyuForCausalLM = None
        self.processor: FuyuProcessor = None
        self.dtype = torch.bfloat16
        self.tokenizer: AutoTokenizer
        self.is_loaded = False

    def load_model(self):
        from transformers import FuyuProcessor, FuyuForCausalLM
        model_path = "adept/fuyu-8b"
        kwargs = {"device_map": self.device}
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=self.dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
        self.processor = FuyuProcessor.from_pretrained(model_path)
        self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
        self.is_loaded = True

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs)
        self.processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=self.tokenizer)

    def generate_caption(
            self, image: Image,
            prompt: str = default_long_prompt,
            replacements=default_replacements,
            max_new_tokens=512
    ):
        # prepare inputs for the model
        # text_prompt = f"{prompt}\n"

        # image = image.convert('RGB')
        model_inputs = self.processor(text=prompt, images=[image])
        model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for k, v in
                        model_inputs.items()}

        generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        prompt_len = model_inputs["input_ids"].shape[-1]
        output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
        output = clean_caption(output, replacements=replacements)
        return output

        # inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
        # for k, v in inputs.items():
        #     inputs[k] = v.to(self.device)

        # # autoregressively generate text
        # generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True)
        # output = generation_text[0]
        #
        # return clean_caption(output, replacements=replacements)
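
Not part of the committed file: a hedged usage sketch of the captioner above. The image path is a placeholder, and note that load_model() as written re-creates the processor a second time, so treat this only as the intended call pattern rather than a verified run.

from PIL import Image
from extensions_built_in.dataset_tools.tools.fuyu_utils import FuyuImageProcessor

captioner = FuyuImageProcessor(device='cuda')
captioner.load_model()  # pulls adept/fuyu-8b and loads it 4-bit (nf4)
caption = captioner.generate_caption(Image.open('example.jpg').convert('RGB'))
print(caption)
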
extensions_built_in/dataset_tools/tools/image_tools.py
ADDED
@@ -0,0 +1,49 @@
from typing import Literal, Type, TYPE_CHECKING, Union

import cv2
import numpy as np
from PIL import Image, ImageOps

Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretch']

img_manipulation_steps = ['contrast_stretch']

img_ext = ['.jpg', '.jpeg', '.png', '.webp']

if TYPE_CHECKING:
    from .llava_utils import LLaVAImageProcessor
    from .fuyu_utils import FuyuImageProcessor

ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor']


def pil_to_cv2(image):
    """Convert a PIL image to a cv2 image."""
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)


def cv2_to_pil(image):
    """Convert a cv2 image to a PIL image."""
    return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))


def load_image(img_path: str):
    image = Image.open(img_path).convert('RGB')
    try:
        # transpose with exif data
        image = ImageOps.exif_transpose(image)
    except Exception as e:
        pass
    return image


def resize_to_max(image, max_width=1024, max_height=1024):
    width, height = image.size
    if width <= max_width and height <= max_height:
        return image

    scale = min(max_width / width, max_height / height)
    width = int(width * scale)
    height = int(height * scale)

    return image.resize((width, height), Image.LANCZOS)
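
Not part of the committed file: a quick worked example of the helpers above; the file path is a placeholder.

img = load_image('photo.jpg')            # RGB, EXIF-rotated when possible
small = resize_to_max(img, 1024, 1024)   # e.g. 4000x3000 -> 1024x768, since scale = min(1024/4000, 1024/3000) = 0.256
cv2_img = pil_to_cv2(small)              # BGR ndarray for OpenCV-based steps
back = cv2_to_pil(cv2_img)               # round-trip back to PIL
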
extensions_built_in/dataset_tools/tools/llava_utils.py
ADDED
@@ -0,0 +1,85 @@

from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption

import torch
from PIL import Image, ImageOps

from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

img_ext = ['.jpg', '.jpeg', '.png', '.webp']


class LLaVAImageProcessor:
    def __init__(self, device='cuda'):
        try:
            from llava.model import LlavaLlamaForCausalLM
        except ImportError:
            # print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
            print(
                "You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
            raise
        self.device = device
        self.model: LlavaLlamaForCausalLM = None
        self.tokenizer: AutoTokenizer = None
        self.image_processor: CLIPImageProcessor = None
        self.is_loaded = False

    def load_model(self):
        from llava.model import LlavaLlamaForCausalLM

        model_path = "4bit/llava-v1.5-13b-3GB"
        # kwargs = {"device_map": "auto"}
        kwargs = {"device_map": self.device}
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
        self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        vision_tower = self.model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=self.device)
        self.image_processor = vision_tower.image_processor
        self.is_loaded = True

    def generate_caption(
            self, image: Image,
            prompt: str = default_long_prompt,
            replacements=default_replacements,
            max_new_tokens=512
    ):
        from llava.conversation import conv_templates, SeparatorStyle
        from llava.utils import disable_torch_init
        from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
        from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
        # question = "how many dogs are in the picture?"
        disable_torch_init()
        conv_mode = "llava_v0"
        conv = conv_templates[conv_mode].copy()
        roles = conv.roles
        image_tensor = self.image_processor.preprocess([image], return_tensors='pt')['pixel_values'].half().cuda()

        inp = f"{roles[0]}: {prompt}"
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        raw_prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids, images=image_tensor, do_sample=True, temperature=0.1,
                max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
                top_p=0.8
            )
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        conv.messages[-1][-1] = outputs
        output = outputs.rsplit('</s>', 1)[0]
        return clean_caption(output, replacements=replacements)
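
Not part of the committed file: a hedged usage sketch of the LLaVA captioner above. It assumes the manual LLaVA install mentioned in __init__ has been done; the image path is a placeholder.

from extensions_built_in.dataset_tools.tools.llava_utils import LLaVAImageProcessor
from extensions_built_in.dataset_tools.tools.image_tools import load_image, resize_to_max

llava = LLaVAImageProcessor(device='cuda')
llava.load_model()  # 4bit/llava-v1.5-13b-3GB, nf4 quantized
caption = llava.generate_caption(resize_to_max(load_image('example.jpg')))
print(caption)
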
extensions_built_in/dataset_tools/tools/sync_tools.py
ADDED
@@ -0,0 +1,279 @@
import os
import requests
import tqdm
from typing import List, Optional, TYPE_CHECKING


def img_root_path(img_id: str):
    return os.path.dirname(os.path.dirname(img_id))


if TYPE_CHECKING:
    from .dataset_tools_config_modules import DatasetSyncCollectionConfig

img_exts = ['.jpg', '.jpeg', '.webp', '.png']

class Photo:
    def __init__(
            self,
            id,
            host,
            width,
            height,
            url,
            filename
    ):
        self.id = str(id)
        self.host = host
        self.width = width
        self.height = height
        self.url = url
        self.filename = filename


def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int):
    if img_width > img_height:
        scale = min_height / img_height
    else:
        scale = min_width / img_width

    new_width = int(img_width * scale)
    new_height = int(img_height * scale)

    return new_width, new_height


def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
    all_images = []
    next_page = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos"

    while True:
        response = requests.get(next_page, headers={
            "Authorization": f"{config.api_key}"
        })
        response.raise_for_status()
        data = response.json()
        all_images.extend(data['media'])
        if 'next_page' in data and data['next_page']:
            next_page = data['next_page']
        else:
            break

    photos = []
    for image in all_images:
        new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
        url = f"{image['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}"
        filename = os.path.basename(image['src']['original'])

        photos.append(Photo(
            id=image['id'],
            host="pexels",
            width=image['width'],
            height=image['height'],
            url=url,
            filename=filename
        ))

    return photos


def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
    headers = {
        # "Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"
        "Authorization": f"Client-ID {config.api_key}"
    }
    # headers['Authorization'] = f"Bearer {token}"

    url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    res_headers = response.headers
    # parse the link header to get the next page
    # 'Link': '<https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=82>; rel="last", <https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=2>; rel="next"'
    has_next_page = False
    if 'Link' in res_headers:
        has_next_page = True
        link_header = res_headers['Link']
        link_header = link_header.split(',')
        link_header = [link.strip() for link in link_header]
        link_header = [link.split(';') for link in link_header]
        link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header]
        link_header = {link[1]: link[0] for link in link_header}

        # get page number from last url
        last_page = link_header['rel="last']
        last_page = last_page.split('?')[1]
        last_page = last_page.split('&')
        last_page = [param.split('=') for param in last_page]
        last_page = {param[0]: param[1] for param in last_page}
        last_page = int(last_page['page'])

    all_images = response.json()

    if has_next_page:
        # assume we start on page 1, so we don't need to get it again
        for page in tqdm.tqdm(range(2, last_page + 1)):
            url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30"
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            all_images.extend(response.json())

    photos = []
    for image in all_images:
        new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
        url = f"{image['urls']['raw']}&w={new_width}"
        filename = f"{image['id']}.jpg"

        photos.append(Photo(
            id=image['id'],
            host="unsplash",
            width=image['width'],
            height=image['height'],
            url=url,
            filename=filename
        ))

    return photos


def get_img_paths(dir_path: str):
    os.makedirs(dir_path, exist_ok=True)
    local_files = os.listdir(dir_path)
    # remove non image files
    local_files = [file for file in local_files if os.path.splitext(file)[1].lower() in img_exts]
    # make full path
    local_files = [os.path.join(dir_path, file) for file in local_files]
    return local_files


def get_local_image_ids(dir_path: str):
    os.makedirs(dir_path, exist_ok=True)
    local_files = get_img_paths(dir_path)
    # assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
    return set([os.path.basename(file).split('.')[0] for file in local_files])


def get_local_image_file_names(dir_path: str):
    os.makedirs(dir_path, exist_ok=True)
    local_files = get_img_paths(dir_path)
    # assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
    return set([os.path.basename(file) for file in local_files])


def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024):
    img_width = photo.width
    img_height = photo.height

    if img_width < min_width or img_height < min_height:
        raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}")

    img_response = requests.get(photo.url)
    img_response.raise_for_status()
    os.makedirs(dir_path, exist_ok=True)

    filename = os.path.join(dir_path, photo.filename)
    with open(filename, 'wb') as file:
        file.write(img_response.content)


def update_caption(img_path: str):
    # if the caption is a txt file, convert it to a json file
    filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
    # see if it exists
    if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json")):
        # todo add poi and what not
        return  # we have a json file
    caption = ""
    # see if txt file exists
    if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")):
        # read it
        with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"), 'r') as file:
            caption = file.read()
        # write json file
        with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json"), 'w') as file:
            file.write(f'{{"caption": "{caption}"}}')

        # delete txt file
        os.remove(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"))


# def equalize_img(img_path: str):
#     input_path = img_path
#     output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path))
#     os.makedirs(os.path.dirname(output_path), exist_ok=True)
#     process_img(
#         img_path=input_path,
#         output_path=output_path,
#         equalize=True,
#         max_size=2056,
#         white_balance=False,
#         gamma_correction=False,
#         strength=0.6,
#     )


# def annotate_depth(img_path: str):
#     # make fake args
#     args = argparse.Namespace()
#     args.annotator = "midas"
#     args.res = 1024
#
#     img = cv2.imread(img_path)
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#
#     output = annotate(img, args)
#
#     output = output.astype('uint8')
#     output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
#
#     os.makedirs(os.path.dirname(img_path), exist_ok=True)
#     output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path))
#
#     cv2.imwrite(output_path, output)


# def invert_depth(img_path: str):
#     img = cv2.imread(img_path)
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     # invert the colors
#     img = cv2.bitwise_not(img)
#
#     os.makedirs(os.path.dirname(img_path), exist_ok=True)
#     output_path = os.path.join(img_root_path(img_path), INVERTED_DEPTH_DIR, os.path.basename(img_path))
#     cv2.imwrite(output_path, img)


#
# # update our list of raw images
# raw_images = get_img_paths(raw_dir)
#
# # update raw captions
# for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"):
#     update_caption(image_id)
#
# # equalize images
# for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"):
#     if img_path not in eq_images:
#         equalize_img(img_path)
#
# # update our list of eq images
# eq_images = get_img_paths(eq_dir)
# # update eq captions
# for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"):
#     update_caption(image_id)
#
# # annotate depth
# depth_dir = os.path.join(root_dir, DEPTH_DIR)
# depth_images = get_img_paths(depth_dir)
# for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"):
#     if img_path not in depth_images:
#         annotate_depth(img_path)
#
# depth_images = get_img_paths(depth_dir)
#
# # invert depth
# inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR)
# inv_depth_images = get_img_paths(inv_depth_dir)
# for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"):
#     if img_path not in inv_depth_images:
#         invert_depth(img_path)
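
Not part of the committed file: a minimal sync sketch combining the pieces above, as SyncFromCollection presumably does. The collection id, directory, and API key are placeholders.

from extensions_built_in.dataset_tools.tools.dataset_tools_config_modules import DatasetSyncCollectionConfig
from extensions_built_in.dataset_tools.tools.sync_tools import (
    get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image,
)

config = DatasetSyncCollectionConfig(
    host='unsplash',
    collection_id='mIPWwLdfct8',        # placeholder collection id
    directory='/path/to/dataset/raw',
    api_key='YOUR_ACCESS_KEY',
)
photos = get_unsplash_images(config) if config.host == 'unsplash' else get_pexels_images(config)
existing = get_local_image_file_names(config.directory)
for photo in photos:
    if photo.filename in existing:
        continue  # already downloaded
    try:
        download_image(photo, config.directory, config.min_width, config.min_height)
    except ValueError as e:
        print(e)  # image smaller than the configured minimum, skipped
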
extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
ADDED
@@ -0,0 +1,235 @@
import copy
import random
from collections import OrderedDict
import os
from contextlib import nullcontext
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader

from toolkit.config_modules import ReferenceDatasetConfig
from toolkit.data_loader import PairedImageDataset
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess
import random
from toolkit.basic import value_map


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class ReferenceSliderConfig:
    def __init__(self, **kwargs):
        self.additional_losses: List[str] = kwargs.get('additional_losses', [])
        self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
        self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])]


class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
    sd: StableDiffusion
    data_loader: DataLoader = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)
        self.prompt_txt_list = None
        self.step_num = 0
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.slider_config = ReferenceSliderConfig(**self.get_conf('slider', {}))

    def load_datasets(self):
        if self.data_loader is None:
            print(f"Loading datasets")
            datasets = []
            for dataset in self.slider_config.datasets:
                print(f" - Dataset: {dataset.pair_folder}")
                config = {
                    'path': dataset.pair_folder,
                    'size': dataset.size,
                    'default_prompt': dataset.target_class,
                    'network_weight': dataset.network_weight,
                    'pos_weight': dataset.pos_weight,
                    'neg_weight': dataset.neg_weight,
                    'pos_folder': dataset.pos_folder,
                    'neg_folder': dataset.neg_folder,
                }
                image_dataset = PairedImageDataset(config)
                datasets.append(image_dataset)

            concatenated_dataset = ConcatDataset(datasets)
            self.data_loader = DataLoader(
                concatenated_dataset,
                batch_size=self.train_config.batch_size,
                shuffle=True,
                num_workers=2
            )

    def before_model_load(self):
        pass

    def hook_before_train_loop(self):
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)
        self.load_datasets()

        pass

    def hook_train_loop(self, batch):
        with torch.no_grad():
            imgs, prompts, network_weights = batch
            network_pos_weight, network_neg_weight = network_weights

            if isinstance(network_pos_weight, torch.Tensor):
                network_pos_weight = network_pos_weight.item()
            if isinstance(network_neg_weight, torch.Tensor):
                network_neg_weight = network_neg_weight.item()

            # get an array of random floats between -weight_jitter and weight_jitter
            loss_jitter_multiplier = 1.0
            weight_jitter = self.slider_config.weight_jitter
            if weight_jitter > 0.0:
                jitter_list = random.uniform(-weight_jitter, weight_jitter)
                orig_network_pos_weight = network_pos_weight
                network_pos_weight += jitter_list
                network_neg_weight += (jitter_list * -1.0)
                # penalize the loss for its distance from network_pos_weight
                # a jitter_list of abs(3.0) on a weight of 5.0 is a 60% jitter
                # so the loss_jitter_multiplier needs to be 0.4
                loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0)

            # if items in network_weight list are tensors, convert them to floats

            dtype = get_torch_dtype(self.train_config.dtype)
            imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
            # split batched images in half so left is negative and right is positive
            negative_images, positive_images = torch.chunk(imgs, 2, dim=3)

            positive_latents = self.sd.encode_images(positive_images)
            negative_latents = self.sd.encode_images(negative_images)

            height = positive_images.shape[2]
            width = positive_images.shape[3]
            batch_size = positive_images.shape[0]

            if self.train_config.gradient_checkpointing:
                # may get disabled elsewhere
                self.sd.unet.enable_gradient_checkpointing()

            noise_scheduler = self.sd.noise_scheduler
            optimizer = self.optimizer
            lr_scheduler = self.lr_scheduler

            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
            timesteps = timesteps.long()

            # get noise
            noise_positive = self.sd.get_latent_noise(
                pixel_height=height,
                pixel_width=width,
                batch_size=batch_size,
                noise_offset=self.train_config.noise_offset,
            ).to(self.device_torch, dtype=dtype)

            noise_negative = noise_positive.clone()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps)
            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps)

            noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
            noise = torch.cat([noise_positive, noise_negative], dim=0)
            timesteps = torch.cat([timesteps, timesteps], dim=0)
            network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]

        self.optimizer.zero_grad()
        noisy_latents.requires_grad = False

        # if training text encoder enable grads, else do context of no grad
        with torch.set_grad_enabled(self.train_config.train_text_encoder):
            # fix issue with them being tuples sometimes
            prompt_list = []
            for prompt in prompts:
                if isinstance(prompt, tuple):
                    prompt = prompt[0]
                prompt_list.append(prompt)
            conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
            conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])

        # if self.model_config.is_xl:
        #     # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
        #     network_multiplier_list = network_multiplier
        #     noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
        #     noise_list = torch.chunk(noise, 2, dim=0)
        #     timesteps_list = torch.chunk(timesteps, 2, dim=0)
        #     conditional_embeds_list = split_prompt_embeds(conditional_embeds)
        # else:
        network_multiplier_list = [network_multiplier]
        noisy_latent_list = [noisy_latents]
        noise_list = [noise]
        timesteps_list = [timesteps]
        conditional_embeds_list = [conditional_embeds]

        losses = []
        # allow to chunk it out to save vram
        for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip(
                network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list
        ):
            with self.network:
                assert self.network.is_active

                self.network.multiplier = network_multiplier

                noise_pred = self.sd.predict_noise(
                    latents=noisy_latents.to(self.device_torch, dtype=dtype),
                    conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
                    timestep=timesteps,
                )
                noise = noise.to(self.device_torch, dtype=dtype)

                if self.sd.prediction_type == 'v_prediction':
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
                    # add min_snr_gamma
                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)

                loss = loss.mean() * loss_jitter_multiplier

                loss_float = loss.item()
                losses.append(loss_float)

                # back propagate loss to free ram
                loss.backward()

        # apply gradients
        optimizer.step()
        lr_scheduler.step()

        # reset network
        self.network.multiplier = 1.0

        loss_dict = OrderedDict(
            {'loss': sum(losses) / len(losses) if len(losses) > 0 else 0.0}
        )

        return loss_dict
        # end hook_train_loop
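
Not part of the committed file: a worked example of the weight-jitter loss scaling in hook_train_loop above. value_map lives in toolkit.basic; the version below is an assumed straight linear remap, consistent with the "60% jitter -> 0.4 multiplier" comment in the trainer.

def value_map(x, in_min, in_max, out_min, out_max):
    # assumed linear remap from [in_min, in_max] to [out_min, out_max]
    return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min

weight_jitter = 5.0
jitter = 3.0                       # sampled from uniform(-5.0, 5.0)
pos_weight = 5.0 + jitter          # 8.0
neg_weight = 5.0 - jitter          # 2.0
multiplier = value_map(abs(jitter), 0.0, weight_jitter, 1.0, 0.0)
print(multiplier)                  # 0.4 -> heavily jittered samples contribute less to the loss
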
extensions_built_in/image_reference_slider_trainer/__init__.py
ADDED
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class ImageReferenceSliderTrainer(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "image_reference_slider_trainer"

    # name is the name of the extension for printing
    name = "Image Reference Slider Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .ImageReferenceSliderTrainerProcess import ImageReferenceSliderTrainerProcess
        return ImageReferenceSliderTrainerProcess


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    ImageReferenceSliderTrainer
]
extensions_built_in/image_reference_slider_trainer/config/train.example.yaml
ADDED
@@ -0,0 +1,107 @@
---
job: extension
config:
  name: example_name
  process:
    - type: 'image_reference_slider_trainer'
      training_folder: "/mnt/Train/out/LoRA"
      device: cuda:0
      # for tensorboard logging
      log_dir: "/home/jaret/Dev/.tensorboard"
      network:
        type: "lora"
        linear: 8
        linear_alpha: 8
      train:
        noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
        steps: 5000
        lr: 1e-4
        train_unet: true
        gradient_checkpointing: true
        train_text_encoder: true
        optimizer: "adamw"
        optimizer_params:
          weight_decay: 1e-2
        lr_scheduler: "constant"
        max_denoising_steps: 1000
        batch_size: 1
        dtype: bf16
        xformers: true
        skip_first_sample: true
        noise_offset: 0.0
      model:
        name_or_path: "/path/to/model.safetensors"
        is_v2: false # for v2 models
        is_xl: false # for SDXL models
        is_v_pred: false # for v-prediction models (most v2 models)
      save:
        dtype: float16 # precision to save
        save_every: 1000 # save every this many steps
        max_step_saves_to_keep: 2 # only affects step counts
      sample:
        sampler: "ddpm" # must match train.noise_scheduler
        sample_every: 100 # sample every this many steps
        width: 512
        height: 512
        prompts:
          - "photo of a woman with red hair taking a selfie --m -3"
          - "photo of a woman with red hair taking a selfie --m -1"
          - "photo of a woman with red hair taking a selfie --m 1"
          - "photo of a woman with red hair taking a selfie --m 3"
          - "close up photo of a man smiling at the camera, in a tank top --m -3"
          - "close up photo of a man smiling at the camera, in a tank top --m -1"
          - "close up photo of a man smiling at the camera, in a tank top --m 1"
          - "close up photo of a man smiling at the camera, in a tank top --m 3"
          - "photo of a blonde woman smiling, barista --m -3"
          - "photo of a blonde woman smiling, barista --m -1"
          - "photo of a blonde woman smiling, barista --m 1"
          - "photo of a blonde woman smiling, barista --m 3"
          - "photo of a Christina Hendricks --m -1"
          - "photo of a Christina Hendricks --m -1"
          - "photo of a Christina Hendricks --m 1"
          - "photo of a Christina Hendricks --m 3"
          - "photo of a Christina Ricci --m -3"
          - "photo of a Christina Ricci --m -1"
          - "photo of a Christina Ricci --m 1"
          - "photo of a Christina Ricci --m 3"
        neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
        seed: 42
        walk_seed: false
        guidance_scale: 7
        sample_steps: 20
        network_multiplier: 1.0

      logging:
        log_every: 10 # log every this many steps
        use_wandb: false # not supported yet
        verbose: false

      slider:
        datasets:
          - pair_folder: "/path/to/folder/side/by/side/images"
            network_weight: 2.0
            target_class: "" # only used as default if caption txt are not present
            size: 512
          - pair_folder: "/path/to/folder/side/by/side/images"
            network_weight: 4.0
            target_class: "" # only used as default if caption txt are not present
            size: 512


# you can put any information you want here, and it will be saved in the model
# the below is an example. I recommend doing trigger words at a minimum
# in the metadata. The software will include this plus some other information
meta:
  name: "[name]" # [name] gets replaced with the name above
  description: A short description of your model
  trigger_words:
    - put
    - trigger
    - words
    - here
  version: '0.1'
  creator:
    name: Your Name
    email: [email protected]
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
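
Not part of the committed file: a small sketch (assumes PyYAML is installed) to sanity-check an edited copy of this example config before launching a job; the file path is a placeholder.

import yaml

with open('config/train.example.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

process = cfg['config']['process'][0]
assert process['type'] == 'image_reference_slider_trainer'
for ds in process['slider']['datasets']:
    # each slider dataset needs a pair_folder of side-by-side images and a network weight
    print(ds['pair_folder'], ds['network_weight'])
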
extensions_built_in/sd_trainer/SDTrainer.py
ADDED
@@ -0,0 +1,1679 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from collections import OrderedDict
|
4 |
+
from typing import Union, Literal, List, Optional
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel
|
8 |
+
|
9 |
+
import torch.functional as F
|
10 |
+
from safetensors.torch import load_file
|
11 |
+
from torch.utils.data import DataLoader, ConcatDataset
|
12 |
+
|
13 |
+
from toolkit import train_tools
|
14 |
+
from toolkit.basic import value_map, adain, get_mean_std
|
15 |
+
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
16 |
+
from toolkit.config_modules import GuidanceConfig
|
17 |
+
from toolkit.data_loader import get_dataloader_datasets
|
18 |
+
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
19 |
+
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
|
20 |
+
from toolkit.image_utils import show_tensors, show_latents
|
21 |
+
from toolkit.ip_adapter import IPAdapter
|
22 |
+
from toolkit.custom_adapter import CustomAdapter
|
23 |
+
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
24 |
+
from toolkit.reference_adapter import ReferenceAdapter
|
25 |
+
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
26 |
+
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
|
27 |
+
apply_learnable_snr_gos, LearnableSNRGamma
|
28 |
+
import gc
|
29 |
+
import torch
|
30 |
+
from jobs.process import BaseSDTrainProcess
|
31 |
+
from torchvision import transforms
|
32 |
+
from diffusers import EMAModel
|
33 |
+
import math
|
34 |
+
from toolkit.train_tools import precondition_model_outputs_flow_match
|
35 |
+
|
36 |
+
|
37 |
+
def flush():
|
38 |
+
torch.cuda.empty_cache()
|
39 |
+
gc.collect()
|
40 |
+
|
41 |
+
|
42 |
+
adapter_transforms = transforms.Compose([
|
43 |
+
transforms.ToTensor(),
|
44 |
+
])
|
45 |
+
|
46 |
+
|
47 |
+
class SDTrainer(BaseSDTrainProcess):
|
48 |
+
|
49 |
+
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
50 |
+
super().__init__(process_id, job, config, **kwargs)
|
51 |
+
self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
|
52 |
+
self.do_prior_prediction = False
|
53 |
+
self.do_long_prompts = False
|
54 |
+
self.do_guided_loss = False
|
55 |
+
self.taesd: Optional[AutoencoderTiny] = None
|
56 |
+
|
57 |
+
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
58 |
+
self.negative_prompt_pool: Union[List[str], None] = None
|
59 |
+
self.batch_negative_prompt: Union[List[str], None] = None
|
60 |
+
|
61 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
62 |
+
|
63 |
+
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
64 |
+
|
65 |
+
self.do_grad_scale = True
|
66 |
+
if self.is_fine_tuning and self.is_bfloat:
|
67 |
+
self.do_grad_scale = False
|
68 |
+
if self.adapter_config is not None:
|
69 |
+
if self.adapter_config.train:
|
70 |
+
self.do_grad_scale = False
|
71 |
+
|
72 |
+
if self.train_config.dtype in ["fp16", "float16"]:
|
73 |
+
# patch the scaler to allow fp16 training
|
74 |
+
org_unscale_grads = self.scaler._unscale_grads_
|
75 |
+
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
76 |
+
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
77 |
+
self.scaler._unscale_grads_ = _unscale_grads_replacer
|
78 |
+
|
79 |
+
self.cached_blank_embeds: Optional[PromptEmbeds] = None
|
80 |
+
self.cached_trigger_embeds: Optional[PromptEmbeds] = None
|
81 |
+
|
82 |
+
|
83 |
+
def before_model_load(self):
|
84 |
+
pass
|
85 |
+
|
86 |
+
def before_dataset_load(self):
|
87 |
+
self.assistant_adapter = None
|
88 |
+
# get adapter assistant if one is set
|
89 |
+
if self.train_config.adapter_assist_name_or_path is not None:
|
90 |
+
adapter_path = self.train_config.adapter_assist_name_or_path
|
91 |
+
|
92 |
+
if self.train_config.adapter_assist_type == "t2i":
|
93 |
+
# dont name this adapter since we are not training it
|
94 |
+
self.assistant_adapter = T2IAdapter.from_pretrained(
|
95 |
+
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
|
96 |
+
).to(self.device_torch)
|
97 |
+
elif self.train_config.adapter_assist_type == "control_net":
|
98 |
+
self.assistant_adapter = ControlNetModel.from_pretrained(
|
99 |
+
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
|
100 |
+
).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
|
101 |
+
else:
|
102 |
+
raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")
|
103 |
+
|
104 |
+
self.assistant_adapter.eval()
|
105 |
+
self.assistant_adapter.requires_grad_(False)
|
106 |
+
flush()
|
107 |
+
if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
|
108 |
+
if self.model_config.is_xl:
|
109 |
+
self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl",
|
110 |
+
torch_dtype=get_torch_dtype(self.train_config.dtype))
|
111 |
+
else:
|
112 |
+
self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd",
|
113 |
+
torch_dtype=get_torch_dtype(self.train_config.dtype))
|
114 |
+
self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch)
|
115 |
+
self.taesd.eval()
|
116 |
+
self.taesd.requires_grad_(False)
|
117 |
+
|
118 |
+
def hook_before_train_loop(self):
|
119 |
+
super().hook_before_train_loop()
|
120 |
+
|
121 |
+
if self.train_config.do_prior_divergence:
|
122 |
+
self.do_prior_prediction = True
|
123 |
+
# move vae to device if we did not cache latents
|
124 |
+
if not self.is_latents_cached:
|
125 |
+
self.sd.vae.eval()
|
126 |
+
self.sd.vae.to(self.device_torch)
|
127 |
+
else:
|
128 |
+
# offload it. Already cached
|
129 |
+
self.sd.vae.to('cpu')
|
130 |
+
flush()
|
131 |
+
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
|
132 |
+
if self.adapter is not None:
|
133 |
+
self.adapter.to(self.device_torch)
|
134 |
+
|
135 |
+
# check if we have regs and using adapter and caching clip embeddings
|
136 |
+
has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
|
137 |
+
is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])
|
138 |
+
|
139 |
+
if has_reg and is_caching_clip_embeddings:
|
140 |
+
# we need a list of unconditional clip image embeds from other datasets to handle regs
|
141 |
+
unconditional_clip_image_embeds = []
|
142 |
+
datasets = get_dataloader_datasets(self.data_loader)
|
143 |
+
for i in range(len(datasets)):
|
144 |
+
unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache
|
145 |
+
|
146 |
+
if len(unconditional_clip_image_embeds) == 0:
|
147 |
+
raise ValueError("No unconditional clip image embeds found. This should not happen")
|
148 |
+
|
149 |
+
self._clip_image_embeds_unconditional = unconditional_clip_image_embeds
|
150 |
+
|
151 |
+
if self.train_config.negative_prompt is not None:
|
152 |
+
if os.path.exists(self.train_config.negative_prompt):
|
153 |
+
with open(self.train_config.negative_prompt, 'r') as f:
|
154 |
+
self.negative_prompt_pool = f.readlines()
|
155 |
+
# remove empty
|
156 |
+
self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""]
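# when negative_prompt points to a file, each non-empty line becomes a candidate negative prompt sampled per batch item later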
|
157 |
+
else:
|
158 |
+
# single prompt
|
159 |
+
self.negative_prompt_pool = [self.train_config.negative_prompt]
|
160 |
+
|
161 |
+
# handle unload text encoder
|
162 |
+
if self.train_config.unload_text_encoder:
|
163 |
+
with torch.no_grad():
|
164 |
+
if self.train_config.train_text_encoder:
|
165 |
+
raise ValueError("Cannot unload text encoder if training text encoder")
|
166 |
+
# cache embeddings
|
167 |
+
|
168 |
+
print("\n***** UNLOADING TEXT ENCODER *****")
|
169 |
+
print("This will train only with a blank prompt or trigger word, if set")
|
170 |
+
print("If this is not what you want, remove the unload_text_encoder flag")
|
171 |
+
print("***********************************")
|
172 |
+
print("")
|
173 |
+
self.sd.text_encoder_to(self.device_torch)
|
174 |
+
self.cached_blank_embeds = self.sd.encode_prompt("")
|
175 |
+
if self.trigger_word is not None:
|
176 |
+
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
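# with the text encoder unloaded, only these cached blank / trigger embeddings are used for conditioning during training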
|
177 |
+
|
178 |
+
# move back to cpu
|
179 |
+
self.sd.text_encoder_to('cpu')
|
180 |
+
flush()
|
181 |
+
|
182 |
+
|
183 |
+
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
|
184 |
+
# to process turbo learning, we make one big step from our current timestep to the end
|
185 |
+
# we then denoise the prediction on that remaining step and target our loss to our target latents
|
186 |
+
# this currently only works on euler_a (as far as I know). It could work on other schedulers but would need to be coded to do so.
|
187 |
+
# needs to be done on each item in batch as they may all have different timesteps
|
188 |
+
batch_size = pred.shape[0]
|
189 |
+
pred_chunks = torch.chunk(pred, batch_size, dim=0)
|
190 |
+
noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0)
|
191 |
+
timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0)
|
192 |
+
latent_chunks = torch.chunk(batch.latents, batch_size, dim=0)
|
193 |
+
noise_chunks = torch.chunk(noise, batch_size, dim=0)
|
194 |
+
|
195 |
+
with torch.no_grad():
|
196 |
+
# set the timesteps to 1000 so we can capture them to calculate the sigmas
|
197 |
+
self.sd.noise_scheduler.set_timesteps(
|
198 |
+
self.sd.noise_scheduler.config.num_train_timesteps,
|
199 |
+
device=self.device_torch
|
200 |
+
)
|
201 |
+
train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach()
|
202 |
+
|
203 |
+
train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach()
|
204 |
+
|
205 |
+
# set the scheduler to a single timestep; we build the step and sigmas for each item in the batch for the partial step
|
206 |
+
self.sd.noise_scheduler.set_timesteps(
|
207 |
+
1,
|
208 |
+
device=self.device_torch
|
209 |
+
)
|
210 |
+
|
211 |
+
denoised_pred_chunks = []
|
212 |
+
target_pred_chunks = []
|
213 |
+
|
214 |
+
for i in range(batch_size):
|
215 |
+
pred_item = pred_chunks[i]
|
216 |
+
noisy_latents_item = noisy_latents_chunks[i]
|
217 |
+
timesteps_item = timesteps_chunks[i]
|
218 |
+
latents_item = latent_chunks[i]
|
219 |
+
noise_item = noise_chunks[i]
|
220 |
+
with torch.no_grad():
|
221 |
+
timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
|
222 |
+
single_step_timestep_schedule = [timesteps_item.squeeze().item()]
|
223 |
+
# extract the sigma idx for our midpoint timestep
|
224 |
+
sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch)
|
225 |
+
|
226 |
+
end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
|
227 |
+
end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch)
|
228 |
+
|
229 |
+
# add noise to our target
|
230 |
+
|
231 |
+
# build the big sigma step from the current sigma down to the randomly chosen end sigma, giving a large remaining denoising step
|
232 |
+
# self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach()
|
233 |
+
self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach()
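# the two-entry sigma schedule makes the next scheduler.step a single large jump from the current sigma to the sampled end sigma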
|
234 |
+
# set our single timestep
|
235 |
+
self.sd.noise_scheduler.timesteps = torch.from_numpy(
|
236 |
+
np.array(single_step_timestep_schedule, dtype=np.float32)
|
237 |
+
).to(device=self.device_torch)
|
238 |
+
|
239 |
+
# set the step index to None so it will be recalculated on first step
|
240 |
+
self.sd.noise_scheduler._step_index = None
|
241 |
+
|
242 |
+
denoised_latent = self.sd.noise_scheduler.step(
|
243 |
+
pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
|
244 |
+
)[0]
|
245 |
+
|
246 |
+
residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(
|
247 |
+
self.train_config.dtype))
|
248 |
+
# remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
|
249 |
+
denoised_latent = denoised_latent - residual_noise
|
250 |
+
|
251 |
+
denoised_pred_chunks.append(denoised_latent)
|
252 |
+
|
253 |
+
denoised_latents = torch.cat(denoised_pred_chunks, dim=0)
|
254 |
+
# set the scheduler back to the original timesteps
|
255 |
+
self.sd.noise_scheduler.set_timesteps(
|
256 |
+
self.sd.noise_scheduler.config.num_train_timesteps,
|
257 |
+
device=self.device_torch
|
258 |
+
)
|
259 |
+
|
260 |
+
output = denoised_latents / self.sd.vae.config['scaling_factor']
|
261 |
+
output = self.sd.vae.decode(output).sample
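# decoding through the vae means the turbo pred/target pair returned below is compared in pixel space against the input images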
|
262 |
+
|
263 |
+
if self.train_config.show_turbo_outputs:
|
264 |
+
# since we are completely denoising, we can show them here
|
265 |
+
with torch.no_grad():
|
266 |
+
show_tensors(output)
|
267 |
+
|
268 |
+
# we return our big partial step denoised latents as our pred and our untouched latents as our target.
|
269 |
+
# you can do mse against the two here or run the denoised through the vae for pixel space loss against the
|
270 |
+
# input tensor images.
|
271 |
+
|
272 |
+
return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
|
273 |
+
|
274 |
+
# you can expand these in a child class to make customization easier
|
275 |
+
def calculate_loss(
|
276 |
+
self,
|
277 |
+
noise_pred: torch.Tensor,
|
278 |
+
noise: torch.Tensor,
|
279 |
+
noisy_latents: torch.Tensor,
|
280 |
+
timesteps: torch.Tensor,
|
281 |
+
batch: 'DataLoaderBatchDTO',
|
282 |
+
mask_multiplier: Union[torch.Tensor, float] = 1.0,
|
283 |
+
prior_pred: Union[torch.Tensor, None] = None,
|
284 |
+
**kwargs
|
285 |
+
):
|
286 |
+
loss_target = self.train_config.loss_target
|
287 |
+
is_reg = any(batch.get_is_reg_list())
|
288 |
+
|
289 |
+
prior_mask_multiplier = None
|
290 |
+
target_mask_multiplier = None
|
291 |
+
dtype = get_torch_dtype(self.train_config.dtype)
|
292 |
+
|
293 |
+
has_mask = batch.mask_tensor is not None
|
294 |
+
|
295 |
+
with torch.no_grad():
|
296 |
+
loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)
|
297 |
+
|
298 |
+
if self.train_config.match_noise_norm:
|
299 |
+
# match the norm of the noise
|
300 |
+
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
|
301 |
+
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
|
302 |
+
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
|
303 |
+
|
304 |
+
if self.train_config.pred_scaler != 1.0:
|
305 |
+
noise_pred = noise_pred * self.train_config.pred_scaler
|
306 |
+
|
307 |
+
target = None
|
308 |
+
|
309 |
+
if self.train_config.target_noise_multiplier != 1.0:
|
310 |
+
noise = noise * self.train_config.target_noise_multiplier
|
311 |
+
|
312 |
+
if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
|
313 |
+
if self.train_config.correct_pred_norm and not is_reg:
|
314 |
+
with torch.no_grad():
|
315 |
+
# this only works if doing a prior pred
|
316 |
+
if prior_pred is not None:
|
317 |
+
prior_mean = prior_pred.mean([2,3], keepdim=True)
|
318 |
+
prior_std = prior_pred.std([2,3], keepdim=True)
|
319 |
+
noise_mean = noise_pred.mean([2,3], keepdim=True)
|
320 |
+
noise_std = noise_pred.std([2,3], keepdim=True)
|
321 |
+
|
322 |
+
mean_adjust = prior_mean - noise_mean
|
323 |
+
std_adjust = prior_std - noise_std
|
324 |
+
|
325 |
+
mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
|
326 |
+
std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
|
327 |
+
|
328 |
+
target_mean = noise_mean + mean_adjust
|
329 |
+
target_std = noise_std + std_adjust
|
330 |
+
|
331 |
+
eps = 1e-5
|
332 |
+
# match the noise to the prior
|
333 |
+
noise = (noise - noise_mean) / (noise_std + eps)
|
334 |
+
noise = noise * (target_std + eps) + target_mean
|
335 |
+
noise = noise.detach()
|
336 |
+
|
337 |
+
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
|
338 |
+
assert not self.train_config.train_turbo
|
339 |
+
with torch.no_grad():
|
340 |
+
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
341 |
+
stretched_mask_multiplier = value_map(
|
342 |
+
mask_multiplier,
|
343 |
+
batch.file_items[0].dataset_config.mask_min_value,
|
344 |
+
1.0,
|
345 |
+
0.0,
|
346 |
+
1.0
|
347 |
+
)
|
348 |
+
|
349 |
+
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
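# the inverted mask weights the prior loss toward regions outside the training mask, so unmasked areas are pulled toward the prior prediction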
|
350 |
+
|
351 |
+
|
352 |
+
# target_mask_multiplier = mask_multiplier
|
353 |
+
# mask_multiplier = 1.0
|
354 |
+
target = noise
|
355 |
+
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
|
356 |
+
# set mask multiplier to 1.0 so we don't double apply it
|
357 |
+
# mask_multiplier = 1.0
|
358 |
+
elif prior_pred is not None and not self.train_config.do_prior_divergence:
|
359 |
+
assert not self.train_config.train_turbo
|
360 |
+
# matching adapter prediction
|
361 |
+
target = prior_pred
|
362 |
+
elif self.sd.prediction_type == 'v_prediction':
|
363 |
+
# v-parameterization training
|
364 |
+
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
|
365 |
+
|
366 |
+
elif self.sd.is_flow_matching:
|
367 |
+
target = (noise - batch.latents).detach()
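# for flow matching models the target is the velocity pointing from the clean latents toward the noise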
|
368 |
+
else:
|
369 |
+
target = noise
|
370 |
+
|
371 |
+
if target is None:
|
372 |
+
target = noise
|
373 |
+
|
374 |
+
pred = noise_pred
|
375 |
+
|
376 |
+
if self.train_config.train_turbo:
|
377 |
+
pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
|
378 |
+
|
379 |
+
ignore_snr = False
|
380 |
+
|
381 |
+
if loss_target == 'source' or loss_target == 'unaugmented':
|
382 |
+
assert not self.train_config.train_turbo
|
383 |
+
# ignore_snr = True
|
384 |
+
if batch.sigmas is None:
|
385 |
+
raise ValueError("Batch sigmas is None. This should not happen")
|
386 |
+
|
387 |
+
# src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
|
388 |
+
denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
|
389 |
+
weighing = batch.sigmas ** -2.0
|
390 |
+
if loss_target == 'source':
|
391 |
+
# denoise the latent and compare to the latent in the batch
|
392 |
+
target = batch.latents
|
393 |
+
elif loss_target == 'unaugmented':
|
394 |
+
# we have to encode images into latents for now
|
395 |
+
# we also denoise, as the unaugmented tensor is not a noisy differential
|
396 |
+
with torch.no_grad():
|
397 |
+
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
|
398 |
+
unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
|
399 |
+
target = unaugmented_latents.detach()
|
400 |
+
|
401 |
+
# Get the target for loss depending on the prediction type
|
402 |
+
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
|
403 |
+
target = target # we are computing loss against denoised latents
|
404 |
+
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
|
405 |
+
target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
|
406 |
+
else:
|
407 |
+
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
|
408 |
+
|
409 |
+
# mse loss without reduction
|
410 |
+
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
|
411 |
+
loss = loss_per_element
|
412 |
+
else:
|
413 |
+
|
414 |
+
if self.train_config.loss_type == "mae":
|
415 |
+
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
416 |
+
else:
|
417 |
+
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
418 |
+
|
419 |
+
# handle linear timesteps and only adjust the weight of the timesteps
|
420 |
+
if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
|
421 |
+
# calculate the weights for the timesteps
|
422 |
+
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
|
423 |
+
timesteps,
|
424 |
+
v2=self.train_config.linear_timesteps2
|
425 |
+
).to(loss.device, dtype=loss.dtype)
|
426 |
+
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
|
427 |
+
loss = loss * timestep_weight
|
428 |
+
|
429 |
+
if self.train_config.do_prior_divergence and prior_pred is not None:
|
430 |
+
loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)
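# prior divergence adds a negated mse term, pushing the prediction away from the prior prediction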
|
431 |
+
|
432 |
+
if self.train_config.train_turbo:
|
433 |
+
mask_multiplier = mask_multiplier[:, 3:, :, :]
|
434 |
+
# resize to the size of the loss
|
435 |
+
mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
|
436 |
+
|
437 |
+
# multiply by our mask
|
438 |
+
loss = loss * mask_multiplier
|
439 |
+
|
440 |
+
prior_loss = None
|
441 |
+
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
|
442 |
+
assert not self.train_config.train_turbo
|
443 |
+
if self.train_config.loss_type == "mae":
|
444 |
+
prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
|
445 |
+
else:
|
446 |
+
prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")
|
447 |
+
|
448 |
+
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
|
449 |
+
if torch.isnan(prior_loss).any():
|
450 |
+
print("Prior loss is nan")
|
451 |
+
prior_loss = None
|
452 |
+
else:
|
453 |
+
prior_loss = prior_loss.mean([1, 2, 3])
|
454 |
+
# loss = loss + prior_loss
|
455 |
+
# loss = loss + prior_loss
|
456 |
+
# loss = loss + prior_loss
|
457 |
+
loss = loss.mean([1, 2, 3])
|
458 |
+
# apply loss multiplier before prior loss
|
459 |
+
loss = loss * loss_multiplier
|
460 |
+
if prior_loss is not None:
|
461 |
+
loss = loss + prior_loss
|
462 |
+
|
463 |
+
if not self.train_config.train_turbo:
|
464 |
+
if self.train_config.learnable_snr_gos:
|
465 |
+
# add snr_gamma
|
466 |
+
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
467 |
+
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
|
468 |
+
# add snr_gamma
|
469 |
+
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma,
|
470 |
+
fixed=True)
|
471 |
+
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
|
472 |
+
# add min_snr_gamma
|
473 |
+
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
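# assumption: apply_snr_weight implements a min-SNR style re-weighting of the per-timestep loss using min_snr_gamma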
|
474 |
+
|
475 |
+
loss = loss.mean()
|
476 |
+
|
477 |
+
# check for additional losses
|
478 |
+
if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
|
479 |
+
|
480 |
+
loss = loss + self.adapter.additional_loss.mean()
|
481 |
+
self.adapter.additional_loss = None
|
482 |
+
|
483 |
+
if self.train_config.target_norm_std:
|
484 |
+
# separate out the batch and channels
|
485 |
+
pred_std = noise_pred.std([2, 3], keepdim=True)
|
486 |
+
norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
|
487 |
+
loss = loss + norm_std_loss
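# auxiliary penalty that pulls the spatial std of the raw prediction toward target_norm_std_value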
|
488 |
+
|
489 |
+
|
490 |
+
return loss
|
491 |
+
|
492 |
+
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
493 |
+
return batch
|
494 |
+
|
495 |
+
def get_guided_loss(
|
496 |
+
self,
|
497 |
+
noisy_latents: torch.Tensor,
|
498 |
+
conditional_embeds: PromptEmbeds,
|
499 |
+
match_adapter_assist: bool,
|
500 |
+
network_weight_list: list,
|
501 |
+
timesteps: torch.Tensor,
|
502 |
+
pred_kwargs: dict,
|
503 |
+
batch: 'DataLoaderBatchDTO',
|
504 |
+
noise: torch.Tensor,
|
505 |
+
unconditional_embeds: Optional[PromptEmbeds] = None,
|
506 |
+
**kwargs
|
507 |
+
):
|
508 |
+
loss = get_guidance_loss(
|
509 |
+
noisy_latents=noisy_latents,
|
510 |
+
conditional_embeds=conditional_embeds,
|
511 |
+
match_adapter_assist=match_adapter_assist,
|
512 |
+
network_weight_list=network_weight_list,
|
513 |
+
timesteps=timesteps,
|
514 |
+
pred_kwargs=pred_kwargs,
|
515 |
+
batch=batch,
|
516 |
+
noise=noise,
|
517 |
+
sd=self.sd,
|
518 |
+
unconditional_embeds=unconditional_embeds,
|
519 |
+
scaler=self.scaler,
|
520 |
+
**kwargs
|
521 |
+
)
|
522 |
+
|
523 |
+
return loss
|
524 |
+
|
525 |
+
def get_guided_loss_targeted_polarity(
|
526 |
+
self,
|
527 |
+
noisy_latents: torch.Tensor,
|
528 |
+
conditional_embeds: PromptEmbeds,
|
529 |
+
match_adapter_assist: bool,
|
530 |
+
network_weight_list: list,
|
531 |
+
timesteps: torch.Tensor,
|
532 |
+
pred_kwargs: dict,
|
533 |
+
batch: 'DataLoaderBatchDTO',
|
534 |
+
noise: torch.Tensor,
|
535 |
+
**kwargs
|
536 |
+
):
|
537 |
+
with torch.no_grad():
|
538 |
+
# Perform targeted guidance (working title)
|
539 |
+
dtype = get_torch_dtype(self.train_config.dtype)
|
540 |
+
|
541 |
+
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
542 |
+
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
543 |
+
|
544 |
+
mean_latents = (conditional_latents + unconditional_latents) / 2.0
|
545 |
+
|
546 |
+
unconditional_diff = (unconditional_latents - mean_latents)
|
547 |
+
conditional_diff = (conditional_latents - mean_latents)
|
548 |
+
|
549 |
+
# we need to determine the amount of signal and noise that would be present at the current timestep
|
550 |
+
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
551 |
+
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
552 |
+
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
|
553 |
+
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
|
554 |
+
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
|
555 |
+
|
556 |
+
# target_noise = noise + unconditional_signal
|
557 |
+
|
558 |
+
conditional_noisy_latents = self.sd.add_noise(
|
559 |
+
mean_latents,
|
560 |
+
noise,
|
561 |
+
timesteps
|
562 |
+
).detach()
|
563 |
+
|
564 |
+
unconditional_noisy_latents = self.sd.add_noise(
|
565 |
+
mean_latents,
|
566 |
+
noise,
|
567 |
+
timesteps
|
568 |
+
).detach()
|
569 |
+
|
570 |
+
# Disable the LoRA network so we can predict parent network knowledge without it
|
571 |
+
self.network.is_active = False
|
572 |
+
self.sd.unet.eval()
|
573 |
+
|
574 |
+
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
575 |
+
# This acts as our control to preserve the unaltered parts of the image.
|
576 |
+
baseline_prediction = self.sd.predict_noise(
|
577 |
+
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
578 |
+
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
579 |
+
timestep=timesteps,
|
580 |
+
guidance_scale=1.0,
|
581 |
+
**pred_kwargs # adapter residuals in here
|
582 |
+
).detach()
|
583 |
+
|
584 |
+
# double up everything to run it through all at once
|
585 |
+
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
586 |
+
cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0)
|
587 |
+
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
588 |
+
|
589 |
+
# since we are dividing the polarity from the middle out, we need to double our network
|
590 |
+
# weights on training since the convergent point will be at half network strength
|
591 |
+
|
592 |
+
negative_network_weights = [weight * -2.0 for weight in network_weight_list]
|
593 |
+
positive_network_weights = [weight * 2.0 for weight in network_weight_list]
|
594 |
+
cat_network_weight_list = positive_network_weights + negative_network_weights
|
595 |
+
|
596 |
+
# turn the LoRA network back on.
|
597 |
+
self.sd.unet.train()
|
598 |
+
self.network.is_active = True
|
599 |
+
|
600 |
+
self.network.multiplier = cat_network_weight_list
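# the doubled batch runs the positive (+2x) and negative (-2x) lora polarities in a single forward pass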
|
601 |
+
|
602 |
+
# do our prediction with LoRA active on the scaled guidance latents
|
603 |
+
prediction = self.sd.predict_noise(
|
604 |
+
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
|
605 |
+
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
|
606 |
+
timestep=cat_timesteps,
|
607 |
+
guidance_scale=1.0,
|
608 |
+
**pred_kwargs # adapter residuals in here
|
609 |
+
)
|
610 |
+
|
611 |
+
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
|
612 |
+
|
613 |
+
pred_pos = pred_pos - baseline_prediction
|
614 |
+
pred_neg = pred_neg - baseline_prediction
|
615 |
+
|
616 |
+
pred_loss = torch.nn.functional.mse_loss(
|
617 |
+
pred_pos.float(),
|
618 |
+
unconditional_diff.float(),
|
619 |
+
reduction="none"
|
620 |
+
)
|
621 |
+
pred_loss = pred_loss.mean([1, 2, 3])
|
622 |
+
|
623 |
+
pred_neg_loss = torch.nn.functional.mse_loss(
|
624 |
+
pred_neg.float(),
|
625 |
+
conditional_diff.float(),
|
626 |
+
reduction="none"
|
627 |
+
)
|
628 |
+
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
|
629 |
+
|
630 |
+
loss = (pred_loss + pred_neg_loss) / 2.0
|
631 |
+
|
632 |
+
# loss = self.apply_snr(loss, timesteps)
|
633 |
+
loss = loss.mean()
|
634 |
+
loss.backward()
|
635 |
+
|
636 |
+
# detach it so the parent class can run backward on no grads without throwing an error
|
637 |
+
loss = loss.detach()
|
638 |
+
loss.requires_grad_(True)
|
639 |
+
|
640 |
+
return loss
|
641 |
+
|
642 |
+
def get_guided_loss_masked_polarity(
|
643 |
+
self,
|
644 |
+
noisy_latents: torch.Tensor,
|
645 |
+
conditional_embeds: PromptEmbeds,
|
646 |
+
match_adapter_assist: bool,
|
647 |
+
network_weight_list: list,
|
648 |
+
timesteps: torch.Tensor,
|
649 |
+
pred_kwargs: dict,
|
650 |
+
batch: 'DataLoaderBatchDTO',
|
651 |
+
noise: torch.Tensor,
|
652 |
+
**kwargs
|
653 |
+
):
|
654 |
+
with torch.no_grad():
|
655 |
+
# Perform targeted guidance (working title)
|
656 |
+
dtype = get_torch_dtype(self.train_config.dtype)
|
657 |
+
|
658 |
+
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
659 |
+
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
660 |
+
inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents)
|
661 |
+
|
662 |
+
mean_latents = (conditional_latents + unconditional_latents) / 2.0
|
663 |
+
|
664 |
+
# unconditional_diff = (unconditional_latents - mean_latents)
|
665 |
+
# conditional_diff = (conditional_latents - mean_latents)
|
666 |
+
|
667 |
+
# we need to determine the amount of signal and noise that would be present at the current timestep
|
668 |
+
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
669 |
+
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
670 |
+
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
|
671 |
+
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
|
672 |
+
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
|
673 |
+
|
674 |
+
# make a differential mask
|
675 |
+
differential_mask = torch.abs(conditional_latents - unconditional_latents)
|
676 |
+
max_differential = \
|
677 |
+
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
|
678 |
+
differential_scaler = 1.0 / max_differential
|
679 |
+
differential_mask = differential_mask * differential_scaler
|
680 |
+
spread_point = 0.1
|
681 |
+
# adjust mask to amplify the differential at 0.1
|
682 |
+
differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
|
683 |
+
# clip it
|
684 |
+
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
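# the mask is normalized to 0-1 and sharpened around 0.1 so it highlights where the conditional and unconditional latents differ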
|
685 |
+
|
686 |
+
# target_noise = noise + unconditional_signal
|
687 |
+
|
688 |
+
conditional_noisy_latents = self.sd.add_noise(
|
689 |
+
conditional_latents,
|
690 |
+
noise,
|
691 |
+
timesteps
|
692 |
+
).detach()
|
693 |
+
|
694 |
+
unconditional_noisy_latents = self.sd.add_noise(
|
695 |
+
unconditional_latents,
|
696 |
+
noise,
|
697 |
+
timesteps
|
698 |
+
).detach()
|
699 |
+
|
700 |
+
inverse_noisy_latents = self.sd.add_noise(
|
701 |
+
inverse_latents,
|
702 |
+
noise,
|
703 |
+
timesteps
|
704 |
+
).detach()
|
705 |
+
|
706 |
+
# Disable the LoRA network so we can predict parent network knowledge without it
|
707 |
+
self.network.is_active = False
|
708 |
+
self.sd.unet.eval()
|
709 |
+
|
710 |
+
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
711 |
+
# This acts as our control to preserve the unaltered parts of the image.
|
712 |
+
# baseline_prediction = self.sd.predict_noise(
|
713 |
+
# latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
714 |
+
# conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
715 |
+
# timestep=timesteps,
|
716 |
+
# guidance_scale=1.0,
|
717 |
+
# **pred_kwargs # adapter residuals in here
|
718 |
+
# ).detach()
|
719 |
+
|
720 |
+
# double up everything to run it through all at once
|
721 |
+
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
722 |
+
cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
|
723 |
+
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
724 |
+
|
725 |
+
# since we are dividing the polarity from the middle out, we need to double our network
|
726 |
+
# weights on training since the convergent point will be at half network strength
|
727 |
+
|
728 |
+
negative_network_weights = [weight * -1.0 for weight in network_weight_list]
|
729 |
+
positive_network_weights = [weight * 1.0 for weight in network_weight_list]
|
730 |
+
cat_network_weight_list = positive_network_weights + negative_network_weights
|
731 |
+
|
732 |
+
# turn the LoRA network back on.
|
733 |
+
self.sd.unet.train()
|
734 |
+
self.network.is_active = True
|
735 |
+
|
736 |
+
self.network.multiplier = cat_network_weight_list
|
737 |
+
|
738 |
+
# do our prediction with LoRA active on the scaled guidance latents
|
739 |
+
prediction = self.sd.predict_noise(
|
740 |
+
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
|
741 |
+
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
|
742 |
+
timestep=cat_timesteps,
|
743 |
+
guidance_scale=1.0,
|
744 |
+
**pred_kwargs # adapter residuals in here
|
745 |
+
)
|
746 |
+
|
747 |
+
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
|
748 |
+
|
749 |
+
# create a loss to balance the mean to 0 between the two predictions
|
750 |
+
differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0
|
751 |
+
|
752 |
+
# pred_pos = pred_pos - baseline_prediction
|
753 |
+
# pred_neg = pred_neg - baseline_prediction
|
754 |
+
|
755 |
+
pred_loss = torch.nn.functional.mse_loss(
|
756 |
+
pred_pos.float(),
|
757 |
+
noise.float(),
|
758 |
+
reduction="none"
|
759 |
+
)
|
760 |
+
# apply mask
|
761 |
+
pred_loss = pred_loss * (1.0 + differential_mask)
|
762 |
+
pred_loss = pred_loss.mean([1, 2, 3])
|
763 |
+
|
764 |
+
pred_neg_loss = torch.nn.functional.mse_loss(
|
765 |
+
pred_neg.float(),
|
766 |
+
noise.float(),
|
767 |
+
reduction="none"
|
768 |
+
)
|
769 |
+
# apply inverse mask
|
770 |
+
pred_neg_loss = pred_neg_loss * (1.0 - differential_mask)
|
771 |
+
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
|
772 |
+
|
773 |
+
# make a loss to balance the losses of the pos and neg so they are equal
|
774 |
+
# differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss)
|
775 |
+
#
|
776 |
+
# differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss
|
777 |
+
#
|
778 |
+
# # add a multiplier to balancing losses to make them the top priority
|
779 |
+
# differential_mean_loss = differential_mean_loss
|
780 |
+
|
781 |
+
# remove the grads from the negative as it is only a balancing loss
|
782 |
+
# pred_neg_loss = pred_neg_loss.detach()
|
783 |
+
|
784 |
+
# loss = pred_loss + pred_neg_loss + differential_mean_loss
|
785 |
+
loss = pred_loss + pred_neg_loss
|
786 |
+
|
787 |
+
# loss = self.apply_snr(loss, timesteps)
|
788 |
+
loss = loss.mean()
|
789 |
+
loss.backward()
|
790 |
+
|
791 |
+
# detach it so the parent class can run backward on no grads without throwing an error
|
792 |
+
loss = loss.detach()
|
793 |
+
loss.requires_grad_(True)
|
794 |
+
|
795 |
+
return loss
|
796 |
+
|
797 |
+
def get_prior_prediction(
|
798 |
+
self,
|
799 |
+
noisy_latents: torch.Tensor,
|
800 |
+
conditional_embeds: PromptEmbeds,
|
801 |
+
match_adapter_assist: bool,
|
802 |
+
network_weight_list: list,
|
803 |
+
timesteps: torch.Tensor,
|
804 |
+
pred_kwargs: dict,
|
805 |
+
batch: 'DataLoaderBatchDTO',
|
806 |
+
noise: torch.Tensor,
|
807 |
+
unconditional_embeds: Optional[PromptEmbeds] = None,
|
808 |
+
conditioned_prompts=None,
|
809 |
+
**kwargs
|
810 |
+
):
|
811 |
+
# todo for embeddings, we need to run without trigger words
|
812 |
+
was_unet_training = self.sd.unet.training
|
813 |
+
was_network_active = False
|
814 |
+
if self.network is not None:
|
815 |
+
was_network_active = self.network.is_active
|
816 |
+
self.network.is_active = False
|
817 |
+
can_disable_adapter = False
|
818 |
+
was_adapter_active = False
|
819 |
+
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or
|
820 |
+
isinstance(self.adapter, ReferenceAdapter) or
|
821 |
+
(isinstance(self.adapter, CustomAdapter))
|
822 |
+
):
|
823 |
+
can_disable_adapter = True
|
824 |
+
was_adapter_active = self.adapter.is_active
|
825 |
+
self.adapter.is_active = False
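# the lora network and any ip/reference/custom adapter are temporarily disabled so the prior prediction reflects the unmodified base model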
|
826 |
+
|
827 |
+
if self.train_config.unload_text_encoder:
|
828 |
+
raise ValueError("Prior predictions currently do not support unloading text encoder")
|
829 |
+
# do a prediction here so we can match its output with network multiplier set to 0.0
|
830 |
+
with torch.no_grad():
|
831 |
+
dtype = get_torch_dtype(self.train_config.dtype)
|
832 |
+
|
833 |
+
embeds_to_use = conditional_embeds.clone().detach()
|
834 |
+
# handle clip vision adapter by removing triggers from prompt and replacing with the class name
|
835 |
+
if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
|
836 |
+
prompt_list = batch.get_caption_list()
|
837 |
+
class_name = ''
|
838 |
+
|
839 |
+
triggers = ['[trigger]', '[name]']
|
840 |
+
remove_tokens = []
|
841 |
+
|
842 |
+
if self.embed_config is not None:
|
843 |
+
triggers.append(self.embed_config.trigger)
|
844 |
+
for i in range(1, self.embed_config.tokens):
|
845 |
+
remove_tokens.append(f"{self.embed_config.trigger}_{i}")
|
846 |
+
if self.embed_config.trigger_class_name is not None:
|
847 |
+
class_name = self.embed_config.trigger_class_name
|
848 |
+
|
849 |
+
if self.adapter is not None:
|
850 |
+
triggers.append(self.adapter_config.trigger)
|
851 |
+
for i in range(1, self.adapter_config.num_tokens):
|
852 |
+
remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
|
853 |
+
if self.adapter_config.trigger_class_name is not None:
|
854 |
+
class_name = self.adapter_config.trigger_class_name
|
855 |
+
|
856 |
+
for idx, prompt in enumerate(prompt_list):
|
857 |
+
for remove_token in remove_tokens:
|
858 |
+
prompt = prompt.replace(remove_token, '')
|
859 |
+
for trigger in triggers:
|
860 |
+
prompt = prompt.replace(trigger, class_name)
|
861 |
+
prompt_list[idx] = prompt
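# trigger tokens are stripped or swapped for the plain class name so the prior is not conditioned on the embedding being trained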
|
862 |
+
|
863 |
+
embeds_to_use = self.sd.encode_prompt(
|
864 |
+
prompt_list,
|
865 |
+
long_prompts=self.do_long_prompts).to(
|
866 |
+
self.device_torch,
|
867 |
+
dtype=dtype).detach()
|
868 |
+
|
869 |
+
# don't use the network on this
|
870 |
+
# self.network.multiplier = 0.0
|
871 |
+
self.sd.unet.eval()
|
872 |
+
|
873 |
+
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
|
874 |
+
# we need to remove the image embeds from the prompt except for flux
|
875 |
+
embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
|
876 |
+
end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
|
877 |
+
embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
|
878 |
+
if unconditional_embeds is not None:
|
879 |
+
unconditional_embeds = unconditional_embeds.clone().detach()
|
880 |
+
unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]
|
881 |
+
|
882 |
+
if unconditional_embeds is not None:
|
883 |
+
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
884 |
+
|
885 |
+
prior_pred = self.sd.predict_noise(
|
886 |
+
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
887 |
+
conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
|
888 |
+
unconditional_embeddings=unconditional_embeds,
|
889 |
+
timestep=timesteps,
|
890 |
+
guidance_scale=self.train_config.cfg_scale,
|
891 |
+
rescale_cfg=self.train_config.cfg_rescale,
|
892 |
+
**pred_kwargs # adapter residuals in here
|
893 |
+
)
|
894 |
+
if was_unet_training:
|
895 |
+
self.sd.unet.train()
|
896 |
+
prior_pred = prior_pred.detach()
|
897 |
+
# remove the residuals as we wont use them on prediction when matching control
|
898 |
+
if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
|
899 |
+
del pred_kwargs['down_intrablock_additional_residuals']
|
900 |
+
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
|
901 |
+
del pred_kwargs['down_block_additional_residuals']
|
902 |
+
if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
|
903 |
+
del pred_kwargs['mid_block_additional_residual']
|
904 |
+
|
905 |
+
if can_disable_adapter:
|
906 |
+
self.adapter.is_active = was_adapter_active
|
907 |
+
# restore network
|
908 |
+
# self.network.multiplier = network_weight_list
|
909 |
+
if self.network is not None:
|
910 |
+
self.network.is_active = was_network_active
|
911 |
+
return prior_pred
|
912 |
+
|
913 |
+
def before_unet_predict(self):
|
914 |
+
pass
|
915 |
+
|
916 |
+
def after_unet_predict(self):
|
917 |
+
pass
|
918 |
+
|
919 |
+
def end_of_training_loop(self):
|
920 |
+
pass
|
921 |
+
|
922 |
+
def predict_noise(
|
923 |
+
self,
|
924 |
+
noisy_latents: torch.Tensor,
|
925 |
+
timesteps: Union[int, torch.Tensor] = 1,
|
926 |
+
conditional_embeds: Union[PromptEmbeds, None] = None,
|
927 |
+
unconditional_embeds: Union[PromptEmbeds, None] = None,
|
928 |
+
**kwargs,
|
929 |
+
):
|
930 |
+
dtype = get_torch_dtype(self.train_config.dtype)
|
931 |
+
return self.sd.predict_noise(
|
932 |
+
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
933 |
+
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
934 |
+
unconditional_embeddings=unconditional_embeds,
|
935 |
+
timestep=timesteps,
|
936 |
+
guidance_scale=self.train_config.cfg_scale,
|
937 |
+
guidance_embedding_scale=self.train_config.cfg_scale,
|
938 |
+
detach_unconditional=False,
|
939 |
+
rescale_cfg=self.train_config.cfg_rescale,
|
940 |
+
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
|
941 |
+
**kwargs
|
942 |
+
)
|
943 |
+
|
944 |
+
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
945 |
+
self.timer.start('preprocess_batch')
|
946 |
+
batch = self.preprocess_batch(batch)
|
947 |
+
dtype = get_torch_dtype(self.train_config.dtype)
|
948 |
+
# sanity check
|
949 |
+
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
|
950 |
+
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
|
951 |
+
if isinstance(self.sd.text_encoder, list):
|
952 |
+
for encoder in self.sd.text_encoder:
|
953 |
+
if encoder.dtype != self.sd.te_torch_dtype:
|
954 |
+
encoder.to(self.sd.te_torch_dtype)
|
955 |
+
else:
|
956 |
+
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
|
957 |
+
self.sd.text_encoder.to(self.sd.te_torch_dtype)
|
958 |
+
|
959 |
+
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
960 |
+
if self.train_config.do_cfg or self.train_config.do_random_cfg:
|
961 |
+
# pick random negative prompts
|
962 |
+
if self.negative_prompt_pool is not None:
|
963 |
+
negative_prompts = []
|
964 |
+
for i in range(noisy_latents.shape[0]):
|
965 |
+
num_neg = random.randint(1, self.train_config.max_negative_prompts)
|
966 |
+
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
|
967 |
+
this_neg_prompt = ', '.join(this_neg_prompts)
|
968 |
+
negative_prompts.append(this_neg_prompt)
|
969 |
+
self.batch_negative_prompt = negative_prompts
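# each batch item gets 1..max_negative_prompts randomly chosen negatives joined into a single negative prompt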
|
970 |
+
else:
|
971 |
+
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
|
972 |
+
|
973 |
+
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
974 |
+
# condition the prompt
|
975 |
+
# todo handle more than one adapter image
|
976 |
+
self.adapter.num_control_images = 1
|
977 |
+
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
|
978 |
+
|
979 |
+
network_weight_list = batch.get_network_weight_list()
|
980 |
+
if self.train_config.single_item_batching:
|
981 |
+
network_weight_list = network_weight_list + network_weight_list
|
982 |
+
|
983 |
+
has_adapter_img = batch.control_tensor is not None
|
984 |
+
has_clip_image = batch.clip_image_tensor is not None
|
985 |
+
has_clip_image_embeds = batch.clip_image_embeds is not None
|
986 |
+
# force it to be true if doing regs as we handle those differently
|
987 |
+
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
|
988 |
+
has_clip_image = True
|
989 |
+
if self._clip_image_embeds_unconditional is not None:
|
990 |
+
has_clip_image_embeds = True # we are caching embeds, handle that differently
|
991 |
+
has_clip_image = False
|
992 |
+
|
993 |
+
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
|
994 |
+
raise ValueError(
|
995 |
+
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
996 |
+
|
997 |
+
match_adapter_assist = False
|
998 |
+
|
999 |
+
# check if we are matching the adapter assistant
|
1000 |
+
if self.assistant_adapter:
|
1001 |
+
if self.train_config.match_adapter_chance == 1.0:
|
1002 |
+
match_adapter_assist = True
|
1003 |
+
elif self.train_config.match_adapter_chance > 0.0:
|
1004 |
+
match_adapter_assist = torch.rand(
|
1005 |
+
(1,), device=self.device_torch, dtype=dtype
|
1006 |
+
) < self.train_config.match_adapter_chance
|
1007 |
+
|
1008 |
+
self.timer.stop('preprocess_batch')
|
1009 |
+
|
1010 |
+
is_reg = False
|
1011 |
+
with torch.no_grad():
|
1012 |
+
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
1013 |
+
for idx, file_item in enumerate(batch.file_items):
|
1014 |
+
if file_item.is_reg:
|
1015 |
+
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
|
1016 |
+
is_reg = True
|
1017 |
+
|
1018 |
+
adapter_images = None
|
1019 |
+
sigmas = None
|
1020 |
+
if has_adapter_img and (self.adapter or self.assistant_adapter):
|
1021 |
+
with self.timer('get_adapter_images'):
|
1022 |
+
# todo move this to data loader
|
1023 |
+
if batch.control_tensor is not None:
|
1024 |
+
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
1025 |
+
# match in channels
|
1026 |
+
if self.assistant_adapter is not None:
|
1027 |
+
in_channels = self.assistant_adapter.config.in_channels
|
1028 |
+
if adapter_images.shape[1] != in_channels:
|
1029 |
+
# we need to match the channels
|
1030 |
+
adapter_images = adapter_images[:, :in_channels, :, :]
|
1031 |
+
else:
|
1032 |
+
raise NotImplementedError("Adapter images now must be loaded with dataloader")
|
1033 |
+
|
1034 |
+
clip_images = None
|
1035 |
+
if has_clip_image:
|
1036 |
+
with self.timer('get_clip_images'):
|
1037 |
+
# todo move this to data loader
|
1038 |
+
if batch.clip_image_tensor is not None:
|
1039 |
+
clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
|
1040 |
+
|
1041 |
+
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
1042 |
+
if batch.mask_tensor is not None:
|
1043 |
+
with self.timer('get_mask_multiplier'):
|
1044 |
+
# upsampling is not supported for bfloat16
|
1045 |
+
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
1046 |
+
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
1047 |
+
mask_multiplier = torch.nn.functional.interpolate(
|
1048 |
+
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
1049 |
+
)
|
1050 |
+
# expand to match latents
|
1051 |
+
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
1052 |
+
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
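# the mask is built in fp16, downsampled to the latent resolution, then broadcast across the latent channels before being applied to the loss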
|
1053 |
+
|
1054 |
+
def get_adapter_multiplier():
|
1055 |
+
if self.adapter and isinstance(self.adapter, T2IAdapter):
|
1056 |
+
# training a t2i adapter, not using it as an assistant.
|
1057 |
+
return 1.0
|
1058 |
+
elif match_adapter_assist:
|
1059 |
+
# training a texture. We want it high
|
1060 |
+
adapter_strength_min = 0.9
|
1061 |
+
adapter_strength_max = 1.0
|
1062 |
+
else:
|
1063 |
+
# training with assistance, we want it low
|
1064 |
+
# adapter_strength_min = 0.4
|
1065 |
+
# adapter_strength_max = 0.7
|
1066 |
+
adapter_strength_min = 0.5
|
1067 |
+
adapter_strength_max = 1.1
|
1068 |
+
|
1069 |
+
adapter_conditioning_scale = torch.rand(
|
1070 |
+
(1,), device=self.device_torch, dtype=dtype
|
1071 |
+
)
|
1072 |
+
|
1073 |
+
adapter_conditioning_scale = value_map(
|
1074 |
+
adapter_conditioning_scale,
|
1075 |
+
0.0,
|
1076 |
+
1.0,
|
1077 |
+
adapter_strength_min,
|
1078 |
+
adapter_strength_max
|
1079 |
+
)
|
1080 |
+
return adapter_conditioning_scale
|
1081 |
+
|
1082 |
+
# flush()
|
1083 |
+
with self.timer('grad_setup'):
|
1084 |
+
|
1085 |
+
# text encoding
|
1086 |
+
grad_on_text_encoder = False
|
1087 |
+
if self.train_config.train_text_encoder:
|
1088 |
+
grad_on_text_encoder = True
|
1089 |
+
|
1090 |
+
if self.embedding is not None:
|
1091 |
+
grad_on_text_encoder = True
|
1092 |
+
|
1093 |
+
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
|
1094 |
+
grad_on_text_encoder = True
|
1095 |
+
|
1096 |
+
if self.adapter_config and self.adapter_config.type == 'te_augmenter':
|
1097 |
+
grad_on_text_encoder = True
|
1098 |
+
|
1099 |
+
# have a blank network so we can wrap it in a context and set multipliers without checking every time
|
1100 |
+
if self.network is not None:
|
1101 |
+
network = self.network
|
1102 |
+
else:
|
1103 |
+
network = BlankNetwork()
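# assumption: BlankNetwork is a no-op stand-in so the "with network:" block below works the same whether or not a lora network is being trained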
|
1104 |
+
|
1105 |
+
# set the weights
|
1106 |
+
network.multiplier = network_weight_list
|
1107 |
+
|
1108 |
+
# activate the network if it exists
|
1109 |
+
|
1110 |
+
prompts_1 = conditioned_prompts
|
1111 |
+
prompts_2 = None
|
1112 |
+
if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
|
1113 |
+
prompts_1 = batch.get_caption_short_list()
|
1114 |
+
prompts_2 = conditioned_prompts
|
1115 |
+
|
1116 |
+
# make the batch splits
|
1117 |
+
if self.train_config.single_item_batching:
|
1118 |
+
if self.model_config.refiner_name_or_path is not None:
|
1119 |
+
raise ValueError("Single item batching is not supported when training the refiner")
|
1120 |
+
batch_size = noisy_latents.shape[0]
|
1121 |
+
# chunk/split everything
|
1122 |
+
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
|
1123 |
+
noise_list = torch.chunk(noise, batch_size, dim=0)
|
1124 |
+
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
|
1125 |
+
conditioned_prompts_list = [[prompt] for prompt in prompts_1]
|
1126 |
+
if imgs is not None:
|
1127 |
+
imgs_list = torch.chunk(imgs, batch_size, dim=0)
|
1128 |
+
else:
|
1129 |
+
imgs_list = [None for _ in range(batch_size)]
|
1130 |
+
if adapter_images is not None:
|
1131 |
+
adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
|
1132 |
+
else:
|
1133 |
+
adapter_images_list = [None for _ in range(batch_size)]
|
1134 |
+
if clip_images is not None:
|
1135 |
+
clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
|
1136 |
+
else:
|
1137 |
+
clip_images_list = [None for _ in range(batch_size)]
|
1138 |
+
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
|
1139 |
+
if prompts_2 is None:
|
1140 |
+
prompt_2_list = [None for _ in range(batch_size)]
|
1141 |
+
else:
|
1142 |
+
prompt_2_list = [[prompt] for prompt in prompts_2]
|
1143 |
+
|
1144 |
+
else:
|
1145 |
+
noisy_latents_list = [noisy_latents]
|
1146 |
+
noise_list = [noise]
|
1147 |
+
timesteps_list = [timesteps]
|
1148 |
+
conditioned_prompts_list = [prompts_1]
|
1149 |
+
imgs_list = [imgs]
|
1150 |
+
adapter_images_list = [adapter_images]
|
1151 |
+
clip_images_list = [clip_images]
|
1152 |
+
mask_multiplier_list = [mask_multiplier]
|
1153 |
+
if prompts_2 is None:
|
1154 |
+
prompt_2_list = [None]
|
1155 |
+
else:
|
1156 |
+
prompt_2_list = [prompts_2]
|
1157 |
+
|
1158 |
+
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip(
|
1159 |
+
noisy_latents_list,
|
1160 |
+
noise_list,
|
1161 |
+
timesteps_list,
|
1162 |
+
conditioned_prompts_list,
|
1163 |
+
imgs_list,
|
1164 |
+
adapter_images_list,
|
1165 |
+
clip_images_list,
|
1166 |
+
mask_multiplier_list,
|
1167 |
+
prompt_2_list
|
1168 |
+
):
|
1169 |
+
|
1170 |
+
# if self.train_config.negative_prompt is not None:
|
1171 |
+
# # add negative prompt
|
1172 |
+
# conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
|
1173 |
+
# range(len(conditioned_prompts))]
|
1174 |
+
# if prompt_2 is not None:
|
1175 |
+
# prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
|
1176 |
+
|
1177 |
+
with network:
|
1178 |
+
# encode clip adapter here so embeds are active for tokenizer
|
1179 |
+
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
|
1180 |
+
with self.timer('encode_clip_vision_embeds'):
|
1181 |
+
if has_clip_image:
|
1182 |
+
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
1183 |
+
clip_images.detach().to(self.device_torch, dtype=dtype),
|
1184 |
+
is_training=True,
|
1185 |
+
has_been_preprocessed=True
|
1186 |
+
)
|
1187 |
+
else:
|
1188 |
+
# just do a blank one
|
1189 |
+
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
1190 |
+
torch.zeros(
|
1191 |
+
(noisy_latents.shape[0], 3, 512, 512),
|
1192 |
+
device=self.device_torch, dtype=dtype
|
1193 |
+
),
|
1194 |
+
is_training=True,
|
1195 |
+
has_been_preprocessed=True,
|
1196 |
+
drop=True
|
1197 |
+
)
|
1198 |
+
# it will be injected into the tokenizer when called
|
1199 |
+
self.adapter(conditional_clip_embeds)
|
1200 |
+
|
1201 |
+
# do the custom adapter after the prior prediction
|
1202 |
+
if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg):
|
1203 |
+
quad_count = random.randint(1, 4)
|
1204 |
+
self.adapter.train()
|
1205 |
+
self.adapter.trigger_pre_te(
|
1206 |
+
tensors_0_1=clip_images if not is_reg else None, # on regs we send none to get random noise
|
1207 |
+
is_training=True,
|
1208 |
+
has_been_preprocessed=True,
|
1209 |
+
quad_count=quad_count,
|
1210 |
+
batch_size=noisy_latents.shape[0]
|
1211 |
+
)
|
1212 |
+
|
1213 |
+
with self.timer('encode_prompt'):
|
1214 |
+
unconditional_embeds = None
|
1215 |
+
if self.train_config.unload_text_encoder:
|
1216 |
+
with torch.set_grad_enabled(False):
|
1217 |
+
embeds_to_use = self.cached_blank_embeds.clone().detach().to(
|
1218 |
+
self.device_torch, dtype=dtype
|
1219 |
+
)
|
1220 |
+
if self.cached_trigger_embeds is not None and not is_reg:
|
1221 |
+
embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
|
1222 |
+
self.device_torch, dtype=dtype
|
1223 |
+
)
|
1224 |
+
conditional_embeds = concat_prompt_embeds(
|
1225 |
+
[embeds_to_use] * noisy_latents.shape[0]
|
1226 |
+
)
|
1227 |
+
if self.train_config.do_cfg:
|
1228 |
+
unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
|
1229 |
+
self.device_torch, dtype=dtype
|
1230 |
+
)
|
1231 |
+
unconditional_embeds = concat_prompt_embeds(
|
1232 |
+
[unconditional_embeds] * noisy_latents.shape[0]
|
1233 |
+
)
|
1234 |
+
|
1235 |
+
if isinstance(self.adapter, CustomAdapter):
|
1236 |
+
self.adapter.is_unconditional_run = False
|
1237 |
+
|
1238 |
+
elif grad_on_text_encoder:
|
1239 |
+
with torch.set_grad_enabled(True):
|
1240 |
+
if isinstance(self.adapter, CustomAdapter):
|
1241 |
+
self.adapter.is_unconditional_run = False
|
1242 |
+
conditional_embeds = self.sd.encode_prompt(
|
1243 |
+
conditioned_prompts, prompt_2,
|
1244 |
+
dropout_prob=self.train_config.prompt_dropout_prob,
|
1245 |
+
long_prompts=self.do_long_prompts).to(
|
1246 |
+
self.device_torch,
|
1247 |
+
dtype=dtype)
|
1248 |
+
|
1249 |
+
if self.train_config.do_cfg:
|
1250 |
+
if isinstance(self.adapter, CustomAdapter):
|
1251 |
+
self.adapter.is_unconditional_run = True
|
1252 |
+
# todo only do one and repeat it
|
1253 |
+
unconditional_embeds = self.sd.encode_prompt(
|
1254 |
+
self.batch_negative_prompt,
|
1255 |
+
self.batch_negative_prompt,
|
1256 |
+
dropout_prob=self.train_config.prompt_dropout_prob,
|
1257 |
+
long_prompts=self.do_long_prompts).to(
|
1258 |
+
self.device_torch,
|
1259 |
+
dtype=dtype)
|
1260 |
+
if isinstance(self.adapter, CustomAdapter):
|
1261 |
+
self.adapter.is_unconditional_run = False
|
1262 |
+
else:
|
1263 |
+
with torch.set_grad_enabled(False):
|
1264 |
+
# make sure it is in eval mode
|
1265 |
+
if isinstance(self.sd.text_encoder, list):
|
1266 |
+
for te in self.sd.text_encoder:
|
1267 |
+
te.eval()
|
1268 |
+
else:
|
1269 |
+
self.sd.text_encoder.eval()
|
1270 |
+
if isinstance(self.adapter, CustomAdapter):
|
1271 |
+
self.adapter.is_unconditional_run = False
|
1272 |
+
conditional_embeds = self.sd.encode_prompt(
|
1273 |
+
conditioned_prompts, prompt_2,
|
1274 |
+
dropout_prob=self.train_config.prompt_dropout_prob,
|
1275 |
+
long_prompts=self.do_long_prompts).to(
|
1276 |
+
self.device_torch,
|
1277 |
+
dtype=dtype)
|
1278 |
+
if self.train_config.do_cfg:
|
1279 |
+
if isinstance(self.adapter, CustomAdapter):
|
1280 |
+
self.adapter.is_unconditional_run = True
|
1281 |
+
unconditional_embeds = self.sd.encode_prompt(
|
1282 |
+
self.batch_negative_prompt,
|
1283 |
+
dropout_prob=self.train_config.prompt_dropout_prob,
|
1284 |
+
long_prompts=self.do_long_prompts).to(
|
1285 |
+
self.device_torch,
|
1286 |
+
dtype=dtype)
|
1287 |
+
if isinstance(self.adapter, CustomAdapter):
|
1288 |
+
self.adapter.is_unconditional_run = False
|
1289 |
+
|
1290 |
+
# detach the embeddings
|
1291 |
+
conditional_embeds = conditional_embeds.detach()
|
1292 |
+
if self.train_config.do_cfg:
|
1293 |
+
unconditional_embeds = unconditional_embeds.detach()
|
1294 |
+
|
1295 |
+
if self.decorator:
|
1296 |
+
conditional_embeds.text_embeds = self.decorator(
|
1297 |
+
conditional_embeds.text_embeds
|
            )
            if self.train_config.do_cfg:
                unconditional_embeds.text_embeds = self.decorator(
                    unconditional_embeds.text_embeds,
                    is_unconditional=True
                )

        # flush()
        pred_kwargs = {}

        if has_adapter_img:
            if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
                    self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
                with torch.set_grad_enabled(self.adapter is not None):
                    adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
                    adapter_multiplier = get_adapter_multiplier()
                    with self.timer('encode_adapter'):
                        down_block_additional_residuals = adapter(adapter_images)
                        if self.assistant_adapter:
                            # not training. detach
                            down_block_additional_residuals = [
                                sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
                                down_block_additional_residuals
                            ]
                        else:
                            down_block_additional_residuals = [
                                sample.to(dtype=dtype) * adapter_multiplier for sample in
                                down_block_additional_residuals
                            ]

                    pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals

        if self.adapter and isinstance(self.adapter, IPAdapter):
            with self.timer('encode_adapter_embeds'):
                # number of images to do if doing a quad image
                quad_count = random.randint(1, 4)
                image_size = self.adapter.input_size
                if has_clip_image_embeds:
                    # todo handle reg images better than this
                    if is_reg:
                        # get unconditional image embeds from cache
                        embeds = [
                            load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
                            range(noisy_latents.shape[0])
                        ]
                        conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                            embeds,
                            quad_count=quad_count
                        )

                        if self.train_config.do_cfg:
                            embeds = [
                                load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
                                range(noisy_latents.shape[0])
                            ]
                            unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                                embeds,
                                quad_count=quad_count
                            )

                    else:
                        conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                            batch.clip_image_embeds,
                            quad_count=quad_count
                        )
                        if self.train_config.do_cfg:
                            unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                                batch.clip_image_embeds_unconditional,
                                quad_count=quad_count
                            )
                elif is_reg:
                    # we will zero it out in the img embedder
                    clip_images = torch.zeros(
                        (noisy_latents.shape[0], 3, image_size, image_size),
                        device=self.device_torch, dtype=dtype
                    ).detach()
                    # drop will zero it out
                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                        clip_images,
                        drop=True,
                        is_training=True,
                        has_been_preprocessed=False,
                        quad_count=quad_count
                    )
                    if self.train_config.do_cfg:
                        unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                            torch.zeros(
                                (noisy_latents.shape[0], 3, image_size, image_size),
                                device=self.device_torch, dtype=dtype
                            ).detach(),
                            is_training=True,
                            drop=True,
                            has_been_preprocessed=False,
                            quad_count=quad_count
                        )
                elif has_clip_image:
                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                        clip_images.detach().to(self.device_torch, dtype=dtype),
                        is_training=True,
                        has_been_preprocessed=True,
                        quad_count=quad_count,
                        # do cfg on clip embeds to normalize the embeddings for when doing cfg
                        # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
                    )
                    if self.train_config.do_cfg:
                        unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                            clip_images.detach().to(self.device_torch, dtype=dtype),
                            is_training=True,
                            drop=True,
                            has_been_preprocessed=True,
                            quad_count=quad_count
                        )
                else:
                    print("No Clip Image")
                    print([file_item.path for file_item in batch.file_items])
                    raise ValueError("Could not find clip image")

            if not self.adapter_config.train_image_encoder:
                # we are not training the image encoder, so we need to detach the embeds
                conditional_clip_embeds = conditional_clip_embeds.detach()
                if self.train_config.do_cfg:
                    unconditional_clip_embeds = unconditional_clip_embeds.detach()

            with self.timer('encode_adapter'):
                self.adapter.train()
                conditional_embeds = self.adapter(
                    conditional_embeds.detach(),
                    conditional_clip_embeds,
                    is_unconditional=False
                )
                if self.train_config.do_cfg:
                    unconditional_embeds = self.adapter(
                        unconditional_embeds.detach(),
                        unconditional_clip_embeds,
                        is_unconditional=True
                    )
                else:
                    # wipe out unconditional
                    self.adapter.last_unconditional = None

        if self.adapter and isinstance(self.adapter, ReferenceAdapter):
            # pass in our scheduler
            self.adapter.noise_scheduler = self.lr_scheduler
            if has_clip_image or has_adapter_img:
                img_to_use = clip_images if has_clip_image else adapter_images
                # currently 0-1 needs to be -1 to 1
                reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype)
                self.adapter.set_reference_images(reference_images)
                self.adapter.noise_scheduler = self.sd.noise_scheduler
            elif is_reg:
                self.adapter.set_blank_reference_images(noisy_latents.shape[0])
            else:
                self.adapter.set_reference_images(None)

        prior_pred = None

        do_reg_prior = False
        # if is_reg and (self.network is not None or self.adapter is not None):
        #     # we are doing a reg image and we have a network or adapter
        #     do_reg_prior = True

        do_inverted_masked_prior = False
        if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
            do_inverted_masked_prior = True

        do_correct_pred_norm_prior = self.train_config.correct_pred_norm

        do_guidance_prior = False

        if batch.unconditional_latents is not None:
            # for this not that, we need a prior pred to normalize
            guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
            if guidance_type == 'tnt':
                do_guidance_prior = True

        if ((
                has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
            with self.timer('prior predict'):
                prior_pred = self.get_prior_prediction(
                    noisy_latents=noisy_latents,
                    conditional_embeds=conditional_embeds,
                    match_adapter_assist=match_adapter_assist,
                    network_weight_list=network_weight_list,
                    timesteps=timesteps,
                    pred_kwargs=pred_kwargs,
                    noise=noise,
                    batch=batch,
                    unconditional_embeds=unconditional_embeds,
                    conditioned_prompts=conditioned_prompts
                )
                if prior_pred is not None:
                    prior_pred = prior_pred.detach()

        # do the custom adapter after the prior prediction
        if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
            quad_count = random.randint(1, 4)
            self.adapter.train()
            conditional_embeds = self.adapter.condition_encoded_embeds(
                tensors_0_1=clip_images,
                prompt_embeds=conditional_embeds,
                is_training=True,
                has_been_preprocessed=True,
                quad_count=quad_count
            )
            if self.train_config.do_cfg and unconditional_embeds is not None:
                unconditional_embeds = self.adapter.condition_encoded_embeds(
                    tensors_0_1=clip_images,
                    prompt_embeds=unconditional_embeds,
                    is_training=True,
                    has_been_preprocessed=True,
                    is_unconditional=True,
                    quad_count=quad_count
                )

        if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None:
            self.adapter.add_extra_values(batch.extra_values.detach())

            if self.train_config.do_cfg:
                self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()),
                                              is_unconditional=True)

        if has_adapter_img:
            if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
                    self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
                if self.train_config.do_cfg:
                    raise ValueError("ControlNetModel is not supported with CFG")
                with torch.set_grad_enabled(self.adapter is not None):
                    adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
                    adapter_multiplier = get_adapter_multiplier()
                    with self.timer('encode_adapter'):
                        # add_text_embeds is pooled_prompt_embeds for sdxl
                        added_cond_kwargs = {}
                        if self.sd.is_xl:
                            added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
                            added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
                        down_block_res_samples, mid_block_res_sample = adapter(
                            noisy_latents,
                            timesteps,
                            encoder_hidden_states=conditional_embeds.text_embeds,
                            controlnet_cond=adapter_images,
                            conditioning_scale=1.0,
                            guess_mode=False,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )
                        pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
                        pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample

        self.before_unet_predict()
        # do a prior pred if we have an unconditional image, we will swap out the guidance later
        if batch.unconditional_latents is not None or self.do_guided_loss:
            # do guided loss
            loss = self.get_guided_loss(
                noisy_latents=noisy_latents,
                conditional_embeds=conditional_embeds,
                match_adapter_assist=match_adapter_assist,
                network_weight_list=network_weight_list,
                timesteps=timesteps,
                pred_kwargs=pred_kwargs,
                batch=batch,
                noise=noise,
                unconditional_embeds=unconditional_embeds,
                mask_multiplier=mask_multiplier,
                prior_pred=prior_pred,
            )

        else:
            with self.timer('predict_unet'):
                if unconditional_embeds is not None:
                    unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
                noise_pred = self.predict_noise(
                    noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                    timesteps=timesteps,
                    conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
                    unconditional_embeds=unconditional_embeds,
                    **pred_kwargs
                )
            self.after_unet_predict()

            with self.timer('calculate_loss'):
                noise = noise.to(self.device_torch, dtype=dtype).detach()
                loss = self.calculate_loss(
                    noise_pred=noise_pred,
                    noise=noise,
                    noisy_latents=noisy_latents,
                    timesteps=timesteps,
                    batch=batch,
                    mask_multiplier=mask_multiplier,
                    prior_pred=prior_pred,
                )
        # check if nan
        if torch.isnan(loss):
            print("loss is nan")
            loss = torch.zeros_like(loss).requires_grad_(True)

        with self.timer('backward'):
            # todo we have multiplier separated. works for now as res are not in same batch, but need to change
            loss = loss * loss_multiplier.mean()
            # IMPORTANT if gradient checkpointing do not leave with network when doing backward
            # it will destroy the gradients. This is because the network is a context manager
            # and will change the multipliers back to 0.0 when exiting. They will be
            # 0.0 for the backward pass and the gradients will be 0.0
            # I spent weeks on fighting this. DON'T DO IT
            # with fsdp_overlap_step_with_backward():
            # if self.is_bfloat:
            #     loss.backward()
            # else:
            if not self.do_grad_scale:
                loss.backward()
            else:
                self.scaler.scale(loss).backward()

        return loss.detach()
        # flush()

    def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
        if isinstance(batch, list):
            batch_list = batch
        else:
            batch_list = [batch]
        total_loss = None
        self.optimizer.zero_grad()
        for batch in batch_list:
            loss = self.train_single_accumulation(batch)
            if total_loss is None:
                total_loss = loss
            else:
                total_loss += loss
            if len(batch_list) > 1 and self.model_config.low_vram:
                torch.cuda.empty_cache()

        if not self.is_grad_accumulation_step:
            # fix this for multi params
            if self.train_config.optimizer != 'adafactor':
                if self.do_grad_scale:
                    self.scaler.unscale_(self.optimizer)
                if isinstance(self.params[0], dict):
                    for i in range(len(self.params)):
                        torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
            # only step if we are not accumulating
            with self.timer('optimizer_step'):
                # self.optimizer.step()
                if not self.do_grad_scale:
                    self.optimizer.step()
                else:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                self.optimizer.zero_grad(set_to_none=True)
            if self.adapter and isinstance(self.adapter, CustomAdapter):
                self.adapter.post_weight_update()
            if self.ema is not None:
                with self.timer('ema_update'):
                    self.ema.update()
        else:
            # gradient accumulation. Just a place for breakpoint
            pass

        # TODO Should we only step scheduler on grad step? If so, need to recalculate last step
        with self.timer('scheduler_step'):
            self.lr_scheduler.step()

        if self.embedding is not None:
            with self.timer('restore_embeddings'):
                # Let's make sure we don't update any embedding weights besides the newly added token
                self.embedding.restore_embeddings()
        if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
            with self.timer('restore_adapter'):
                # Let's make sure we don't update any embedding weights besides the newly added token
                self.adapter.restore_embeddings()

        loss_dict = OrderedDict(
            {'loss': loss.item()}
        )

        self.end_of_training_loop()

        return loss_dict
extensions_built_in/sd_trainer/__init__.py
ADDED
@@ -0,0 +1,30 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# This is for generic training (LoRA, Dreambooth, FineTuning)
class SDTrainerExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "sd_trainer"

    # name is the name of the extension for printing
    name = "SD Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .SDTrainer import SDTrainer
        return SDTrainer


# for backwards compatibility
class TextualInversionTrainer(SDTrainerExtension):
    uid = "textual_inversion_trainer"


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    SDTrainerExtension, TextualInversionTrainer
]
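
A new extension can follow the same shape as the module above; the "my_trainer" uid, module, and process class in this sketch are hypothetical placeholders, not files from this commit.

from toolkit.extension import Extension


class MyTrainerExtension(Extension):
    # uid is what a job config references in its `type` field
    uid = "my_trainer"  # hypothetical uid
    name = "My Trainer"

    @classmethod
    def get_process(cls):
        # deferred import keeps toolkit startup fast
        from .MyTrainerProcess import MyTrainerProcess  # hypothetical module
        return MyTrainerProcess


AI_TOOLKIT_EXTENSIONS = [MyTrainerExtension]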
extensions_built_in/sd_trainer/config/train.example.yaml
ADDED
@@ -0,0 +1,91 @@
---
job: extension
config:
  name: test_v1
  process:
    - type: 'textual_inversion_trainer'
      training_folder: "out/TI"
      device: cuda:0
      # for tensorboard logging
      log_dir: "out/.tensorboard"
      embedding:
        trigger: "your_trigger_here"
        tokens: 12
        init_words: "man with short brown hair"
        save_format: "safetensors" # 'safetensors' or 'pt'
      save:
        dtype: float16 # precision to save
        save_every: 100 # save every this many steps
        max_step_saves_to_keep: 5 # only affects step counts
      datasets:
        - folder_path: "/path/to/dataset"
          caption_ext: "txt"
          default_caption: "[trigger]"
          buckets: true
          resolution: 512
      train:
        noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
        steps: 3000
        weight_jitter: 0.0
        lr: 5e-5
        train_unet: false
        gradient_checkpointing: true
        train_text_encoder: false
        optimizer: "adamw"
        # optimizer: "prodigy"
        optimizer_params:
          weight_decay: 1e-2
        lr_scheduler: "constant"
        max_denoising_steps: 1000
        batch_size: 4
        dtype: bf16
        xformers: true
        min_snr_gamma: 5.0
        # skip_first_sample: true
        noise_offset: 0.0 # not needed for this
      model:
        # objective reality v2
        name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
        is_v2: false # for v2 models
        is_xl: false # for SDXL models
        is_v_pred: false # for v-prediction models (most v2 models)
      sample:
        sampler: "ddpm" # must match train.noise_scheduler
        sample_every: 100 # sample every this many steps
        width: 512
        height: 512
        prompts:
          - "photo of [trigger] laughing"
          - "photo of [trigger] smiling"
          - "[trigger] close up"
          - "dark scene [trigger] frozen"
          - "[trigger] nighttime"
          - "a painting of [trigger]"
          - "a drawing of [trigger]"
          - "a cartoon of [trigger]"
          - "[trigger] pixar style"
          - "[trigger] costume"
        neg: ""
        seed: 42
        walk_seed: false
        guidance_scale: 7
        sample_steps: 20
        network_multiplier: 1.0

      logging:
        log_every: 10 # log every this many steps
        use_wandb: false # not supported yet
        verbose: false

# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this
# plus some other information for you automatically
meta:
  # [name] gets replaced with the name above
  name: "[name]"
  # version: '1.0'
  # creator:
  #   name: Your Name
  #   email: [email protected]
  #   website: https://your.website
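
The config above is plain YAML: a top-level job, a named config, and a list of process blocks whose type matches an extension uid. A minimal sketch of reading it, assuming PyYAML is installed and the file sits at the path committed here:

import yaml

with open("extensions_built_in/sd_trainer/config/train.example.yaml", "r") as f:
    cfg = yaml.safe_load(f)

print(cfg["job"])             # extension
print(cfg["config"]["name"])  # test_v1
for proc in cfg["config"]["process"]:
    # each process block selects an extension by its uid and carries its own settings
    print(proc["type"], proc["train"]["steps"], proc["train"]["lr"])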
extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py
ADDED
@@ -0,0 +1,533 @@
import copy
import random
from collections import OrderedDict
import os
from contextlib import nullcontext
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader

from toolkit.config_modules import ReferenceDatasetConfig
from toolkit.data_loader import PairedImageDataset
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds, build_latent_image_batch_for_prompt_pair
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess
import random

import random
from collections import OrderedDict
from tqdm import tqdm

from toolkit.config_modules import SliderConfig
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
from toolkit import train_tools
from toolkit.prompt_utils import \
    EncodedPromptPair, ACTION_TYPES_SLIDER, \
    EncodedAnchor, concat_prompt_pairs, \
    concat_anchors, PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \
    split_prompt_pairs

import torch


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class UltimateSliderConfig(SliderConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.additional_losses: List[str] = kwargs.get('additional_losses', [])
        self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
        self.img_loss_weight: float = kwargs.get('img_loss_weight', 1.0)
        self.cfg_loss_weight: float = kwargs.get('cfg_loss_weight', 1.0)
        self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])]


class UltimateSliderTrainerProcess(BaseSDTrainProcess):
    sd: StableDiffusion
    data_loader: DataLoader = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)
        self.prompt_txt_list = None
        self.step_num = 0
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.slider_config = UltimateSliderConfig(**self.get_conf('slider', {}))

        self.prompt_cache = PromptEmbedsCache()
        self.prompt_pairs: list[EncodedPromptPair] = []
        self.anchor_pairs: list[EncodedAnchor] = []
        # keep track of prompt chunk size
        self.prompt_chunk_size = 1

        # store a list of all the prompts from the dataset so we can cache it
        self.dataset_prompts = []
        self.train_with_dataset = self.slider_config.datasets is not None and len(self.slider_config.datasets) > 0

    def load_datasets(self):
        if self.data_loader is None and \
                self.slider_config.datasets is not None and len(self.slider_config.datasets) > 0:
            print(f"Loading datasets")
            datasets = []
            for dataset in self.slider_config.datasets:
                print(f" - Dataset: {dataset.pair_folder}")
                config = {
                    'path': dataset.pair_folder,
                    'size': dataset.size,
                    'default_prompt': dataset.target_class,
                    'network_weight': dataset.network_weight,
                    'pos_weight': dataset.pos_weight,
                    'neg_weight': dataset.neg_weight,
                    'pos_folder': dataset.pos_folder,
                    'neg_folder': dataset.neg_folder,
                }
                image_dataset = PairedImageDataset(config)
                datasets.append(image_dataset)

                # capture all the prompts from it so we can cache the embeds
                self.dataset_prompts += image_dataset.get_all_prompts()

            concatenated_dataset = ConcatDataset(datasets)
            self.data_loader = DataLoader(
                concatenated_dataset,
                batch_size=self.train_config.batch_size,
                shuffle=True,
                num_workers=2
            )

    def before_model_load(self):
        pass

    def hook_before_train_loop(self):
        # load any datasets if they were passed
        self.load_datasets()

        # read line by line from file
        if self.slider_config.prompt_file:
            self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
            with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
                self.prompt_txt_list = f.readlines()
                # clean empty lines
                self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]

            self.print(f"Found {len(self.prompt_txt_list)} prompts.")

            if not self.slider_config.prompt_tensors:
                print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.")
                # shuffle
                random.shuffle(self.prompt_txt_list)
                # trim to max steps
                self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
                # trim list to our max steps

        cache = PromptEmbedsCache()

        # get encoded latents for our prompts
        with torch.no_grad():
            # list of neutrals. Can come from file or be empty
            neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]

            # build the prompts to cache
            prompts_to_cache = []
            for neutral in neutral_list:
                for target in self.slider_config.targets:
                    prompt_list = [
                        f"{target.target_class}",  # target_class
                        f"{target.target_class} {neutral}",  # target_class with neutral
                        f"{target.positive}",  # positive_target
                        f"{target.positive} {neutral}",  # positive_target with neutral
                        f"{target.negative}",  # negative_target
                        f"{target.negative} {neutral}",  # negative_target with neutral
                        f"{neutral}",  # neutral
                        f"{target.positive} {target.negative}",  # both targets
                        f"{target.negative} {target.positive}",  # both targets reverse
                    ]
                    prompts_to_cache += prompt_list

            # remove duplicates
            prompts_to_cache = list(dict.fromkeys(prompts_to_cache))

            # trim to max steps if max steps is lower than prompt count
            prompts_to_cache = prompts_to_cache[:self.train_config.steps]

            if len(self.dataset_prompts) > 0:
                # add the prompts from the dataset
                prompts_to_cache += self.dataset_prompts

            # encode them
            cache = encode_prompts_to_cache(
                prompt_list=prompts_to_cache,
                sd=self.sd,
                cache=cache,
                prompt_tensor_file=self.slider_config.prompt_tensors
            )

            prompt_pairs = []
            prompt_batches = []
            for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False):
                for target in self.slider_config.targets:
                    prompt_pair_batch = build_prompt_pair_batch_from_cache(
                        cache=cache,
                        target=target,
                        neutral=neutral,
                    )
                    if self.slider_config.batch_full_slide:
                        # concat the prompt pairs
                        # this allows us to run the entire 4 part process in one shot (for slider)
                        self.prompt_chunk_size = 4
                        concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu')
                        prompt_pairs += [concat_prompt_pair_batch]
                    else:
                        self.prompt_chunk_size = 1
                        # do them one at a time (probably not necessary after new optimizations)
                        prompt_pairs += [x.to('cpu') for x in prompt_pair_batch]

            # move to cpu to save vram
            # We don't need text encoder anymore, but keep it on cpu for sampling
            # if text encoder is list
            if isinstance(self.sd.text_encoder, list):
                for encoder in self.sd.text_encoder:
                    encoder.to("cpu")
            else:
                self.sd.text_encoder.to("cpu")
            self.prompt_cache = cache
            self.prompt_pairs = prompt_pairs
            # end hook_before_train_loop

        # move vae to device so we can encode on the fly
        # todo cache latents
        self.sd.vae.to(self.device_torch)
        self.sd.vae.eval()
        self.sd.vae.requires_grad_(False)

        if self.train_config.gradient_checkpointing:
            # may get disabled elsewhere
            self.sd.unet.enable_gradient_checkpointing()

        flush()
        # end hook_before_train_loop

    def hook_train_loop(self, batch):
        dtype = get_torch_dtype(self.train_config.dtype)

        with torch.no_grad():
            ### LOOP SETUP ###
            noise_scheduler = self.sd.noise_scheduler
            optimizer = self.optimizer
            lr_scheduler = self.lr_scheduler

            ### TARGET_PROMPTS ###
            # get a random pair
            prompt_pair: EncodedPromptPair = self.prompt_pairs[
                torch.randint(0, len(self.prompt_pairs), (1,)).item()
            ]
            # move to device and dtype
            prompt_pair.to(self.device_torch, dtype=dtype)

            ### PREP REFERENCE IMAGES ###

            imgs, prompts, network_weights = batch
            network_pos_weight, network_neg_weight = network_weights

            if isinstance(network_pos_weight, torch.Tensor):
                network_pos_weight = network_pos_weight.item()
            if isinstance(network_neg_weight, torch.Tensor):
                network_neg_weight = network_neg_weight.item()

            # get an array of random floats between -weight_jitter and weight_jitter
            weight_jitter = self.slider_config.weight_jitter
            if weight_jitter > 0.0:
                jitter_list = random.uniform(-weight_jitter, weight_jitter)
                network_pos_weight += jitter_list
                network_neg_weight += (jitter_list * -1.0)

            # if items in network_weight list are tensors, convert them to floats
            imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
            # split batched images in half so left is negative and right is positive
            negative_images, positive_images = torch.chunk(imgs, 2, dim=3)

            height = positive_images.shape[2]
            width = positive_images.shape[3]
            batch_size = positive_images.shape[0]

            positive_latents = self.sd.encode_images(positive_images)
            negative_latents = self.sd.encode_images(negative_images)

            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
            current_timestep_index = timesteps.item()
            current_timestep = noise_scheduler.timesteps[current_timestep_index]
            timesteps = timesteps.long()

            # get noise
            noise_positive = self.sd.get_latent_noise(
                pixel_height=height,
                pixel_width=width,
                batch_size=batch_size,
                noise_offset=self.train_config.noise_offset,
            ).to(self.device_torch, dtype=dtype)

            noise_negative = noise_positive.clone()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps)
            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps)

            ### CFG SLIDER TRAINING PREP ###

            # get CFG txt latents
            noisy_cfg_latents = build_latent_image_batch_for_prompt_pair(
                pos_latent=noisy_positive_latents,
                neg_latent=noisy_negative_latents,
                prompt_pair=prompt_pair,
                prompt_chunk_size=self.prompt_chunk_size,
            )
            noisy_cfg_latents.requires_grad = False

            assert not self.network.is_active

            # 4.20 GB RAM for 512x512
            positive_latents = self.sd.predict_noise(
                latents=noisy_cfg_latents,
                text_embeddings=train_tools.concat_prompt_embeddings(
                    prompt_pair.positive_target,  # negative prompt
                    prompt_pair.negative_target,  # positive prompt
                    self.train_config.batch_size,
                ),
                timestep=current_timestep,
                guidance_scale=1.0
            )
            positive_latents.requires_grad = False

            neutral_latents = self.sd.predict_noise(
                latents=noisy_cfg_latents,
                text_embeddings=train_tools.concat_prompt_embeddings(
                    prompt_pair.positive_target,  # negative prompt
                    prompt_pair.empty_prompt,  # positive prompt (normally neutral)
                    self.train_config.batch_size,
                ),
                timestep=current_timestep,
                guidance_scale=1.0
            )
            neutral_latents.requires_grad = False

            unconditional_latents = self.sd.predict_noise(
                latents=noisy_cfg_latents,
                text_embeddings=train_tools.concat_prompt_embeddings(
                    prompt_pair.positive_target,  # negative prompt
                    prompt_pair.positive_target,  # positive prompt
                    self.train_config.batch_size,
                ),
                timestep=current_timestep,
                guidance_scale=1.0
            )
            unconditional_latents.requires_grad = False

            positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
            neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
            unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
            prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
            noisy_cfg_latents_chunks = torch.chunk(noisy_cfg_latents, self.prompt_chunk_size, dim=0)
            assert len(prompt_pair_chunks) == len(noisy_cfg_latents_chunks)

            noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
            noise = torch.cat([noise_positive, noise_negative], dim=0)
            timesteps = torch.cat([timesteps, timesteps], dim=0)
            network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]

            flush()

        loss_float = None
        loss_mirror_float = None

        self.optimizer.zero_grad()
        noisy_latents.requires_grad = False

        # TODO allow both processes to train the text encoder; for now, we just train the unet and cache all text encodes
        # if training text encoder enable grads, else do context of no grad
        # with torch.set_grad_enabled(self.train_config.train_text_encoder):
        #     # text encoding
        #     embedding_list = []
        #     # embed the prompts
        #     for prompt in prompts:
        #         embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
        #         embedding_list.append(embedding)
        #     conditional_embeds = concat_prompt_embeds(embedding_list)
        #     conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])

        if self.train_with_dataset:
            embedding_list = []
            with torch.set_grad_enabled(self.train_config.train_text_encoder):
                for prompt in prompts:
                    # get embedding from cache
                    embedding = self.prompt_cache[prompt]
                    embedding = embedding.to(self.device_torch, dtype=dtype)
                    embedding_list.append(embedding)
                conditional_embeds = concat_prompt_embeds(embedding_list)
                # double up so we can do both sides of the slider
                conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        else:
            # throw error. Not supported yet
            raise Exception("Datasets and targets required for ultimate slider")

        if self.model_config.is_xl:
            # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
            network_multiplier_list = network_multiplier
            noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
            noise_list = torch.chunk(noise, 2, dim=0)
            timesteps_list = torch.chunk(timesteps, 2, dim=0)
            conditional_embeds_list = split_prompt_embeds(conditional_embeds)
        else:
            network_multiplier_list = [network_multiplier]
            noisy_latent_list = [noisy_latents]
            noise_list = [noise]
            timesteps_list = [timesteps]
            conditional_embeds_list = [conditional_embeds]

        ## DO REFERENCE IMAGE TRAINING ##

        reference_image_losses = []
        # allow to chunk it out to save vram
        for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip(
                network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list
        ):
            with self.network:
                assert self.network.is_active

                self.network.multiplier = network_multiplier

                noise_pred = self.sd.predict_noise(
                    latents=noisy_latents.to(self.device_torch, dtype=dtype),
                    conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
                    timestep=timesteps,
                )
                noise = noise.to(self.device_torch, dtype=dtype)

                if self.sd.prediction_type == 'v_prediction':
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                # todo add snr gamma here
                if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
                    # add min_snr_gamma
                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)

                loss = loss.mean()
                loss = loss * self.slider_config.img_loss_weight
                loss_slide_float = loss.item()

                loss_float = loss.item()
                reference_image_losses.append(loss_float)

                # back propagate loss to free ram
                loss.backward()
                flush()

        ## DO CFG SLIDER TRAINING ##

        cfg_loss_list = []

        with self.network:
            assert self.network.is_active
            for prompt_pair_chunk, \
                    noisy_cfg_latent_chunk, \
                    positive_latents_chunk, \
                    neutral_latents_chunk, \
                    unconditional_latents_chunk \
                    in zip(
                prompt_pair_chunks,
                noisy_cfg_latents_chunks,
                positive_latents_chunks,
                neutral_latents_chunks,
                unconditional_latents_chunks,
            ):
                self.network.multiplier = prompt_pair_chunk.multiplier_list

                target_latents = self.sd.predict_noise(
                    latents=noisy_cfg_latent_chunk,
                    text_embeddings=train_tools.concat_prompt_embeddings(
                        prompt_pair_chunk.positive_target,  # negative prompt
                        prompt_pair_chunk.target_class,  # positive prompt
                        self.train_config.batch_size,
                    ),
                    timestep=current_timestep,
                    guidance_scale=1.0
                )

                guidance_scale = 1.0

                offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk)

                # make offset multiplier based on actions
                offset_multiplier_list = []
                for action in prompt_pair_chunk.action_list:
                    if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE:
                        offset_multiplier_list += [-1.0]
                    elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE:
                        offset_multiplier_list += [1.0]

                offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype)
                # make offset multiplier match rank of offset
                offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1)
                offset *= offset_multiplier

                offset_neutral = neutral_latents_chunk
                # offsets are already adjusted on a per-batch basis
                offset_neutral += offset

                # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
                loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
                    # match batch size
                    timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
                    # add min_snr_gamma
                    loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler,
                                            self.train_config.min_snr_gamma)

                loss = loss.mean() * prompt_pair_chunk.weight * self.slider_config.cfg_loss_weight

                loss.backward()
                cfg_loss_list.append(loss.item())
                del target_latents
                del offset_neutral
                del loss
                flush()

        # apply gradients
        optimizer.step()
        lr_scheduler.step()

        # reset network
        self.network.multiplier = 1.0

        reference_image_loss = sum(reference_image_losses) / len(reference_image_losses) if len(
            reference_image_losses) > 0 else 0.0
        cfg_loss = sum(cfg_loss_list) / len(cfg_loss_list) if len(cfg_loss_list) > 0 else 0.0

        loss_dict = OrderedDict({
            'loss/img': reference_image_loss,
            'loss/cfg': cfg_loss,
        })

        return loss_dict
        # end hook_train_loop
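
The CFG-slider loss above regresses the LoRA-modified prediction toward neutral + direction * guidance_scale * (positive - unconditional). A standalone sketch of that target computation on dummy tensors follows; every shape and value here is a placeholder, not data from the trainer.

import torch

batch, c, h, w = 2, 4, 64, 64
positive = torch.randn(batch, c, h, w)       # prediction conditioned on the positive target
neutral = torch.randn(batch, c, h, w)        # prediction conditioned on the empty prompt
unconditional = torch.randn(batch, c, h, w)  # prediction with the positive target on both sides

guidance_scale = 1.0
# per-item direction, mirroring ERASE_NEGATIVE (-1) / ENHANCE_NEGATIVE (+1) actions
direction = torch.tensor([-1.0, 1.0]).view(batch, 1, 1, 1)

offset = guidance_scale * (positive - unconditional) * direction
target = neutral + offset                    # regression target for the network pass

pred = torch.randn(batch, c, h, w, requires_grad=True)
loss = torch.nn.functional.mse_loss(pred, target)
loss.backward()
print(loss.item())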
extensions_built_in/ultimate_slider_trainer/__init__.py
ADDED
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class UltimateSliderTrainer(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "ultimate_slider_trainer"

    # name is the name of the extension for printing
    name = "Ultimate Slider Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .UltimateSliderTrainerProcess import UltimateSliderTrainerProcess
        return UltimateSliderTrainerProcess


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    UltimateSliderTrainer
]
extensions_built_in/ultimate_slider_trainer/config/train.example.yaml
ADDED
@@ -0,0 +1,107 @@
---
job: extension
config:
  name: example_name
  process:
    - type: 'image_reference_slider_trainer'
      training_folder: "/mnt/Train/out/LoRA"
      device: cuda:0
      # for tensorboard logging
      log_dir: "/home/jaret/Dev/.tensorboard"
      network:
        type: "lora"
        linear: 8
        linear_alpha: 8
      train:
        noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
        steps: 5000
        lr: 1e-4
        train_unet: true
        gradient_checkpointing: true
        train_text_encoder: true
        optimizer: "adamw"
        optimizer_params:
          weight_decay: 1e-2
        lr_scheduler: "constant"
        max_denoising_steps: 1000
        batch_size: 1
        dtype: bf16
        xformers: true
        skip_first_sample: true
        noise_offset: 0.0
      model:
        name_or_path: "/path/to/model.safetensors"
        is_v2: false # for v2 models
        is_xl: false # for SDXL models
        is_v_pred: false # for v-prediction models (most v2 models)
      save:
        dtype: float16 # precision to save
        save_every: 1000 # save every this many steps
        max_step_saves_to_keep: 2 # only affects step counts
      sample:
        sampler: "ddpm" # must match train.noise_scheduler
        sample_every: 100 # sample every this many steps
        width: 512
        height: 512
        prompts:
          - "photo of a woman with red hair taking a selfie --m -3"
          - "photo of a woman with red hair taking a selfie --m -1"
          - "photo of a woman with red hair taking a selfie --m 1"
          - "photo of a woman with red hair taking a selfie --m 3"
          - "close up photo of a man smiling at the camera, in a tank top --m -3"
          - "close up photo of a man smiling at the camera, in a tank top --m -1"
          - "close up photo of a man smiling at the camera, in a tank top --m 1"
          - "close up photo of a man smiling at the camera, in a tank top --m 3"
          - "photo of a blonde woman smiling, barista --m -3"
          - "photo of a blonde woman smiling, barista --m -1"
          - "photo of a blonde woman smiling, barista --m 1"
          - "photo of a blonde woman smiling, barista --m 3"
          - "photo of a Christina Hendricks --m -3"
          - "photo of a Christina Hendricks --m -1"
          - "photo of a Christina Hendricks --m 1"
          - "photo of a Christina Hendricks --m 3"
          - "photo of a Christina Ricci --m -3"
          - "photo of a Christina Ricci --m -1"
          - "photo of a Christina Ricci --m 1"
          - "photo of a Christina Ricci --m 3"
        neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
        seed: 42
        walk_seed: false
        guidance_scale: 7
        sample_steps: 20
        network_multiplier: 1.0

      logging:
        log_every: 10 # log every this many steps
        use_wandb: false # not supported yet
        verbose: false

      slider:
        datasets:
          - pair_folder: "/path/to/folder/side/by/side/images"
            network_weight: 2.0
            target_class: "" # only used as default if caption txt are not present
            size: 512
          - pair_folder: "/path/to/folder/side/by/side/images"
            network_weight: 4.0
            target_class: "" # only used as default if caption txt are not present
            size: 512


# you can put any information you want here, and it will be saved in the model
# the below is an example. I recommend doing trigger words at a minimum
# in the metadata. The software will include this plus some other information
meta:
  name: "[name]" # [name] gets replaced with the name above
  description: A short description of your model
  trigger_words:
    - put
    - trigger
    - words
    - here
  version: '0.1'
  creator:
    name: Your Name
    email: [email protected]
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
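
The "--m <value>" suffix on the sample prompts above appears to carry a per-prompt network multiplier, so one slider can be previewed at several strengths in a single sample pass. The parser below is only an illustrative sketch of splitting that suffix out, not the repo's own prompt handling.

def split_prompt_multiplier(raw: str, default: float = 1.0):
    # split "some prompt --m -3" into ("some prompt", -3.0); hypothetical helper
    if "--m" in raw:
        text, _, value = raw.partition("--m")
        return text.strip(), float(value.strip())
    return raw.strip(), default


prompt, multiplier = split_prompt_multiplier(
    "photo of a woman with red hair taking a selfie --m -3"
)
print(prompt)      # photo of a woman with red hair taking a selfie
print(multiplier)  # -3.0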