AINxtGen commited on
Commit
c22961b
1 Parent(s): 6f61bd9
Files changed (50)
  1. .gitignore +181 -0
  2. .gitmodules +12 -0
  3. FAQ.md +10 -0
  4. LICENSE +21 -0
  5. README.md +40 -12
  6. README_origin.md +468 -0
  7. app.py +6 -0
  8. assets/glif.svg +40 -0
  9. assets/lora_ease_ui.png +0 -0
  10. build_and_push_docker.yaml +8 -0
  11. config/examples/extract.example.yml +75 -0
  12. config/examples/generate.example.yaml +60 -0
  13. config/examples/mod_lora_scale.yaml +48 -0
  14. config/examples/modal/modal_train_lora_flux_24gb.yaml +96 -0
  15. config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml +98 -0
  16. config/examples/train_lora_flux_24gb.yaml +96 -0
  17. config/examples/train_lora_flux_schnell_24gb.yaml +98 -0
  18. config/examples/train_lora_sd35_large_24gb.yaml +97 -0
  19. config/examples/train_slider.example.yml +230 -0
  20. docker/Dockerfile +31 -0
  21. extensions/example/ExampleMergeModels.py +129 -0
  22. extensions/example/__init__.py +25 -0
  23. extensions/example/config/config.example.yaml +48 -0
  24. extensions_built_in/advanced_generator/Img2ImgGenerator.py +256 -0
  25. extensions_built_in/advanced_generator/PureLoraGenerator.py +102 -0
  26. extensions_built_in/advanced_generator/ReferenceGenerator.py +212 -0
  27. extensions_built_in/advanced_generator/__init__.py +59 -0
  28. extensions_built_in/advanced_generator/config/train.example.yaml +91 -0
  29. extensions_built_in/concept_replacer/ConceptReplacer.py +151 -0
  30. extensions_built_in/concept_replacer/__init__.py +26 -0
  31. extensions_built_in/concept_replacer/config/train.example.yaml +91 -0
  32. extensions_built_in/dataset_tools/DatasetTools.py +20 -0
  33. extensions_built_in/dataset_tools/SuperTagger.py +196 -0
  34. extensions_built_in/dataset_tools/SyncFromCollection.py +131 -0
  35. extensions_built_in/dataset_tools/__init__.py +43 -0
  36. extensions_built_in/dataset_tools/tools/caption.py +53 -0
  37. extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py +187 -0
  38. extensions_built_in/dataset_tools/tools/fuyu_utils.py +66 -0
  39. extensions_built_in/dataset_tools/tools/image_tools.py +49 -0
  40. extensions_built_in/dataset_tools/tools/llava_utils.py +85 -0
  41. extensions_built_in/dataset_tools/tools/sync_tools.py +279 -0
  42. extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py +235 -0
  43. extensions_built_in/image_reference_slider_trainer/__init__.py +25 -0
  44. extensions_built_in/image_reference_slider_trainer/config/train.example.yaml +107 -0
  45. extensions_built_in/sd_trainer/SDTrainer.py +1679 -0
  46. extensions_built_in/sd_trainer/__init__.py +30 -0
  47. extensions_built_in/sd_trainer/config/train.example.yaml +91 -0
  48. extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py +533 -0
  49. extensions_built_in/ultimate_slider_trainer/__init__.py +25 -0
  50. extensions_built_in/ultimate_slider_trainer/config/train.example.yaml +107 -0
.gitignore ADDED
@@ -0,0 +1,181 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
161
+
162
+ /env.sh
163
+ /models
164
+ /custom/*
165
+ !/custom/.gitkeep
166
+ /.tmp
167
+ /venv.bkp
168
+ /venv.*
169
+ /config/*
170
+ !/config/examples
171
+ !/config/_PUT_YOUR_CONFIGS_HERE).txt
172
+ /output/*
173
+ !/output/.gitkeep
174
+ /extensions/*
175
+ !/extensions/example
176
+ /temp
177
+ /wandb
178
+ .vscode/settings.json
179
+ .DS_Store
180
+ ._.DS_Store
181
+ merge_file.py
.gitmodules ADDED
@@ -0,0 +1,12 @@
1
+ [submodule "repositories/sd-scripts"]
2
+ path = repositories/sd-scripts
3
+ url = https://github.com/kohya-ss/sd-scripts.git
4
+ [submodule "repositories/leco"]
5
+ path = repositories/leco
6
+ url = https://github.com/p1atdev/LECO
7
+ [submodule "repositories/batch_annotator"]
8
+ path = repositories/batch_annotator
9
+ url = https://github.com/ostris/batch-annotator
10
+ [submodule "repositories/ipadapter"]
11
+ path = repositories/ipadapter
12
+ url = https://github.com/tencent-ailab/IP-Adapter.git
FAQ.md ADDED
@@ -0,0 +1,10 @@
1
+ # FAQ
2
+
3
+ WIP. Will continue to add things as they are needed.
4
+
5
+ ## FLUX.1 Training
6
+
7
+ #### How much VRAM is required to train a LoRA on FLUX.1?
8
+
9
+ 24GB minimum is required.
10
+
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Ostris, LLC
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,40 @@
1
- ---
2
- title: Flux LoRA Trainning On Modal
3
- emoji: 🏢
4
- colorFrom: blue
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # FLUX LoRA Training on Modal
2
+
3
+ ## IMPORTANT - READ THIS
4
+
5
+ This Space allows you to train LoRA models for FLUX using Modal's GPU resources. You **must** use your own API tokens for this Space to function.
6
+
7
+ ## Setup Instructions:
8
+
9
+ 1. **Create Accounts:** Create free accounts on [Hugging Face](https://huggingface.co), [Modal](https://modal.com), and [Weights & Biases](https://wandb.ai) (if you want to save training info).
10
+ 2. **Get API Tokens:**
11
+ * **Hugging Face:** Obtain a "write" access token from your Hugging Face [settings/tokens](https://huggingface.co/settings/tokens).
12
+ * **Modal:** Get your Modal API token from your [Modal dashboard](https://modal.com/).
13
+ * **WandB:** Generate a WandB API key from your Weights & Biases settings if you plan to use WandB.
14
+ 3. **Duplicate This Space:** Duplicate (clone) this Hugging Face Space to your own account.
15
+ 4. **Add API Tokens as Secrets:** In your duplicated space, navigate to "Settings" -> "Variables and Secrets" and add:
16
+ * `HF_TOKEN`: Your Hugging Face write access token.
17
+ * `WANDB_API_KEY`: Your Weights & Biases API key.
18
+ 5. **Upload your dataset**
19
+ * You must upload your dataset to the `/root/ai-toolkit` folder. The images can be in a subfolder but that folder's path needs to begin with `/root/ai-toolkit/`. For instance: `/root/ai-toolkit/my-dataset`
20
+ * Make sure each image file name matches a corresponding text caption in the same folder, e.g. `image1.jpg` and `image1.txt`.
21
+ * You can upload a zip file of your dataset, or just a collection of images and text files.
22
+ 6. **Customize and Train:**
23
+ * Go to the `App` tab, and use the Gradio interface to train your LoRA.
24
+ * Enter the required information, upload your dataset, and configure the training parameters.
25
+ * Choose to upload a zip file or multiple images
26
+ * Make sure your image file names match a corresponding `.txt` file.
27
+ * Click the "Start Training" button and wait for the training to complete (check Modal for logs, WandB for training data).
28
+ * The UI will automatically upload the file(s) to the Modal compute environment
29
+ 7. **View Results:**
30
+ * Trained LoRA models will be automatically pushed to your Hugging Face account if you enable the option and have the necessary write token set
31
+ * Samples, Logs, optimizer and other training information will be stored on WandB if enabled.
32
+ * Models, optimizer, and samples are always stored in `Storage > flux-lora-models` on Modal.
33
+
34
+ ## Notes
35
+
36
+ * Training data will always be located in a dataset folder under `/root/ai-toolkit/your-dataset` in your config; this is required to train correctly. If you upload a folder named `my-dataset` as a zip, the folder path to reference will be `/root/ai-toolkit/my-dataset`.
37
+ * Make sure each image has a corresponding text file with the same name (see the sketch at the end of this README).
38
+ * Training and downloading samples will take a while, so be patient, especially in low VRAM mode.
39
+
40
+ If you encounter any problems, please open a new issue on the [ostris/ai-toolkit](https://github.com/ostris/ai-toolkit) GitHub repository.
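Because the caption pairing and the folder name inside the zip are the two things that most often go wrong, here is a minimal, hypothetical Python sketch for checking a local dataset and packaging it before upload. The folder name `my-dataset` and the helper itself are illustrative only; the Space does not ship this script.

```python
import zipfile
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # formats the trainer supports

def check_and_zip(dataset_dir: str, zip_path: str) -> None:
    """Verify every image has a same-named .txt caption, then zip the folder."""
    root = Path(dataset_dir)
    images = [p for p in root.iterdir() if p.suffix.lower() in IMAGE_EXTS]
    missing = [p.name for p in images if not p.with_suffix(".txt").exists()]
    if missing:
        raise SystemExit(f"Missing captions for: {missing}")

    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for p in sorted(root.iterdir()):
            if p.suffix.lower() in IMAGE_EXTS or p.suffix.lower() == ".txt":
                # keep the folder name so it unpacks as /root/ai-toolkit/my-dataset
                zf.write(p, arcname=f"{root.name}/{p.name}")

if __name__ == "__main__":
    check_and_zip("my-dataset", "my-dataset.zip")
```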
README_origin.md ADDED
@@ -0,0 +1,468 @@
1
+ # AI Toolkit by Ostris
2
+
3
+ ## IMPORTANT NOTE - READ THIS
4
+ This is my research repo. I do a lot of experiments in it and it is possible that I will break things.
5
+ If something breaks, check out an earlier commit. This repo can train a lot of things, and it is
6
+ hard to keep up with all of them.
7
+
8
+ ## Support my work
9
+
10
+ <a href="https://glif.app" target="_blank">
11
+ <img alt="glif.app" src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/glif.svg?v=1" width="256" height="auto">
12
+ </a>
13
+
14
+
15
+ My work on this project would not be possible without the amazing support of [Glif](https://glif.app/) and everyone on the
16
+ team. If you want to support me, support Glif. [Join the site](https://glif.app/),
17
+ [Join us on Discord](https://discord.com/invite/nuR9zZ2nsh), [follow us on Twitter](https://x.com/heyglif)
18
+ and come make some cool stuff with us
19
+
20
+ ## Installation
21
+
22
+ Requirements:
23
+ - python >3.10
24
+ - Nvidia GPU with enough ram to do what you need
25
+ - python venv
26
+ - git
27
+
28
+
29
+
30
+ Linux:
31
+ ```bash
32
+ git clone https://github.com/ostris/ai-toolkit.git
33
+ cd ai-toolkit
34
+ git submodule update --init --recursive
35
+ python3 -m venv venv
36
+ source venv/bin/activate
37
+ # .\venv\Scripts\activate on windows
38
+ # install torch first
39
+ pip3 install torch
40
+ pip3 install -r requirements.txt
41
+ ```
42
+
43
+ Windows:
44
+ ```bash
45
+ git clone https://github.com/ostris/ai-toolkit.git
46
+ cd ai-toolkit
47
+ git submodule update --init --recursive
48
+ python -m venv venv
49
+ .\venv\Scripts\activate
50
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
51
+ pip install -r requirements.txt
52
+ ```
53
+
54
+ ## FLUX.1 Training
55
+
56
+ ### Tutorial
57
+
58
+ To get started quickly, check out [@araminta_k](https://x.com/araminta_k)'s tutorial on [Finetuning Flux Dev on a 3090](https://www.youtube.com/watch?v=HzGW_Kyermg) with 24GB VRAM.
59
+
60
+
61
+ ### Requirements
62
+ You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control
63
+ your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize
64
+ the model on CPU and should allow it to train with monitors attached. Users have gotten it to work on Windows with WSL,
65
+ but there are some reports of a bug when running on windows natively.
66
+ I have only tested on linux for now. This is still extremely experimental
67
+ and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all.
68
+
69
+ ### FLUX.1-dev
70
+
71
+ FLUX.1-dev has a non-commercial license, which means anything you train will inherit the
72
+ non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
73
+ Otherwise, this will fail. Here are the required steps to set up access.
74
+
75
+ 1. Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
76
+ 2. Make a file named `.env` in the root of this folder
77
+ 3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here` (a usage sketch follows these steps)
78
+
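For illustration, here is a hedged sketch of how such a token is typically consumed at runtime. It assumes the optional `python-dotenv` package and the `huggingface_hub` client; the toolkit itself may read the variable differently.

```python
import os

from dotenv import load_dotenv       # assumption: python-dotenv is installed
from huggingface_hub import login

load_dotenv()                         # pulls HF_TOKEN=... from the .env file into the environment
token = os.environ.get("HF_TOKEN")
if token is None:
    raise SystemExit("HF_TOKEN is not set; add it to .env or export it")

# authenticates this machine so the gated FLUX.1-dev weights can be downloaded
login(token=token)
```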
79
+ ### FLUX.1-schnell
80
+
81
+ FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want and it does not require a HF_TOKEN to train.
82
+ However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter).
83
+ It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended.
84
+
85
+ To use it, you just need to add the assistant to the `model` section of your config file like so:
86
+
87
+ ```yaml
88
+ model:
89
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
90
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"
91
+ is_flux: true
92
+ quantize: true
93
+ ```
94
+
95
+ You also need to adjust your sample steps since schnell does not require as many
96
+
97
+ ```yaml
98
+ sample:
99
+ guidance_scale: 1 # schnell does not do guidance
100
+ sample_steps: 4 # 1 - 4 works well
101
+ ```
102
+
103
+ ### Training
104
+ 1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
105
+ 2. Edit the file following the comments in the file
106
+ 3. Run the file like so `python run.py config/whatever_you_want.yml`
107
+
108
+ A folder with the name from the config file will be created inside the training folder when you start. It will have all
109
+ checkpoints and images in it. You can stop the training at any time using Ctrl+C, and when you resume, it will pick back up
110
+ from the last checkpoint.
111
+
112
+ IMPORTANT: If you press Ctrl+C while it is saving, it will likely corrupt that checkpoint, so wait until it is done saving.
113
+
114
+ ### Need help?
115
+
116
+ Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU)
117
+ and ask for help there. However, please refrain from PMing me directly with general questions or support requests. Ask in the Discord
118
+ and I will answer when I can.
119
+
120
+ ## Gradio UI
121
+
122
+ To get started training locally with a custom UI, once you have followed the steps above and `ai-toolkit` is installed:
123
+
124
+ ```bash
125
+ cd ai-toolkit #in case you are not yet in the ai-toolkit folder
126
+ huggingface-cli login #provide a `write` token to publish your LoRA at the end
127
+ python flux_train_ui.py
128
+ ```
129
+
130
+ You will instantiate a UI that will let you upload your images, caption them, train and publish your LoRA
131
+ ![image](assets/lora_ease_ui.png)
132
+
133
+
134
+ ## Training in RunPod
135
+ Example RunPod template: **runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04**
136
+ > You need a minimum of 24GB VRAM; pick a GPU of your preference.
137
+
138
+ #### Example config ($0.5/hr):
139
+ - 1x A40 (48 GB VRAM)
140
+ - 19 vCPU 100 GB RAM
141
+
142
+ #### Custom overrides (you need some storage to clone FLUX.1, store datasets, store trained models and samples):
143
+ - ~120 GB Disk
144
+ - ~120 GB Pod Volume
145
+ - Start Jupyter Notebook
146
+
147
+ ### 1. Setup
148
+ ```
149
+ git clone https://github.com/ostris/ai-toolkit.git
150
+ cd ai-toolkit
151
+ git submodule update --init --recursive
152
+ python -m venv venv
153
+ source venv/bin/activate
154
+ pip install torch
155
+ pip install -r requirements.txt
156
+ pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
157
+ ```
158
+ ### 2. Upload your dataset
159
+ - Create a new folder in the root, name it `dataset` or whatever you like.
160
+ - Drag and drop your .jpg, .jpeg, or .png images and .txt files inside the newly created dataset folder.
161
+
162
+ ### 3. Login into Hugging Face with an Access Token
163
+ - Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
164
+ - Run ```huggingface-cli login``` and paste your token.
165
+
166
+ ### 4. Training
167
+ - Copy an example config file located at ```config/examples``` to the config folder and rename it to ```whatever_you_want.yml```.
168
+ - Edit the config following the comments in the file.
169
+ - Change ```folder_path: "/path/to/images/folder"``` to your dataset path like ```folder_path: "/workspace/ai-toolkit/your-dataset"```.
170
+ - Run the file: ```python run.py config/whatever_you_want.yml```.
171
+
172
+ ### Screenshot from RunPod
173
+ <img width="1728" alt="RunPod Training Screenshot" src="https://github.com/user-attachments/assets/53a1b8ef-92fa-4481-81a7-bde45a14a7b5">
174
+
175
+ ## Training in Modal
176
+
177
+ ### 1. Setup
178
+ #### ai-toolkit:
179
+ ```
180
+ git clone https://github.com/ostris/ai-toolkit.git
181
+ cd ai-toolkit
182
+ git submodule update --init --recursive
183
+ python -m venv venv
184
+ source venv/bin/activate
185
+ pip install torch
186
+ pip install -r requirements.txt
187
+ pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
188
+ ```
189
+ #### Modal:
190
+ - Run `pip install modal` to install the modal Python package.
191
+ - Run `modal setup` to authenticate (if this doesn’t work, try `python -m modal setup`).
192
+
193
+ #### Hugging Face:
194
+ - Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
195
+ - Run `huggingface-cli login` and paste your token.
196
+
197
+ ### 2. Upload your dataset
198
+ - Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`.
199
+
200
+ ### 3. Configs
201
+ - Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```.
202
+ - Edit the config following the comments in the file, **<ins>be careful and follow the example `/root/ai-toolkit` paths</ins>**.
203
+
204
+ ### 4. Edit run_modal.py
205
+ - Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like:
206
+
207
+ ```
208
+ code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
209
+ ```
210
+ - Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and a 2 hour timeout)_, as sketched below.
211
+
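For reference, the `GPU` and `Timeout` above are just arguments to the Modal decorator inside `run_modal.py`. The snippet below is a hedged sketch of that pattern built from Modal's public API, not a copy of the actual file; the function name `main` and the specific values are illustrative.

```python
import modal

app = modal.App("flux-lora-training")  # app name is illustrative

# mount the local repo into the container, as set in step 4
code_mount = modal.Mount.from_local_dir(
    "/Users/username/ai-toolkit", remote_path="/root/ai-toolkit"
)

@app.function(
    gpu="A100",            # e.g. "A100-80GB" or "H100" if you need more VRAM
    timeout=2 * 60 * 60,   # seconds; raise this for longer training runs
    mounts=[code_mount],
)
def main(config_file_list_str: str):
    # the real function parses the config list and launches training
    ...
```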
212
+ ### 5. Training
213
+ - Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`.
214
+ - You can monitor your training in your local terminal, or on [modal.com](https://modal.com/).
215
+ - Models, samples and optimizer will be stored in `Storage > flux-lora-models`.
216
+
217
+ ### 6. Saving the model
218
+ - Check contents of the volume by running `modal volume ls flux-lora-models`.
219
+ - Download the content by running `modal volume get flux-lora-models your-model-name`.
220
+ - Example: `modal volume get flux-lora-models my_first_flux_lora_v1`.
221
+
222
+ ### Screenshot from Modal
223
+
224
+ <img width="1728" alt="Modal Traning Screenshot" src="https://github.com/user-attachments/assets/7497eb38-0090-49d6-8ad9-9c8ea7b5388b">
225
+
226
+ ---
227
+
228
+ ## Dataset Preparation
229
+
230
+ Datasets generally need to be a folder containing images and associated text files. Currently, the only supported
231
+ formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images
232
+ but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption.
233
+ You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically
234
+ replaced.
235
+
236
+ Images are never upscaled but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**.
237
+ The loader will automatically resize them and can handle varying aspect ratios.
238
+
239
+
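As a hedged sketch of the caption convention above (same-named `.txt` files, with `[trigger]` swapped for the configured `trigger_word`), the snippet below mirrors the described behaviour; it is not the toolkit's actual dataloader.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # webp currently has issues

def load_captions(folder: str, trigger_word: str | None = None) -> dict[str, str]:
    """Return {image filename: caption}, resolving the [trigger] placeholder."""
    captions: dict[str, str] = {}
    for image in Path(folder).iterdir():
        if image.suffix.lower() not in IMAGE_EXTS:
            continue
        # image2.jpg pairs with image2.txt; the caption is the whole file
        caption = image.with_suffix(".txt").read_text(encoding="utf-8").strip()
        if trigger_word:
            caption = caption.replace("[trigger]", trigger_word)
        captions[image.name] = caption
    return captions

# example: load_captions("/path/to/images/folder", trigger_word="p3r5on")
```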
240
+ ## Training Specific Layers
241
+
242
+ To train specific layers with LoRA, you can use the `only_if_contains` network kwargs. For instance, if you want to train only the 2 layers
243
+ used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your
244
+ network kwargs like so:
245
+
246
+ ```yaml
247
+ network:
248
+ type: "lora"
249
+ linear: 128
250
+ linear_alpha: 128
251
+ network_kwargs:
252
+ only_if_contains:
253
+ - "transformer.single_transformer_blocks.7.proj_out"
254
+ - "transformer.single_transformer_blocks.20.proj_out"
255
+ ```
256
+
257
+ The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal
258
+ the suffixes of the names of the layers you want to train (see the inspection sketch further below). You can also use this method to only train specific groups of weights.
259
+ For instance to only train the `single_transformer` for FLUX.1, you can use the following:
260
+
261
+ ```yaml
262
+ network:
263
+ type: "lora"
264
+ linear: 128
265
+ linear_alpha: 128
266
+ network_kwargs:
267
+ only_if_contains:
268
+ - "transformer.single_transformer_blocks."
269
+ ```
270
+
271
+ You can also exclude layers by their names by using the `ignore_if_contains` network kwarg. For example, to exclude all the single transformer blocks:
272
+
273
+
274
+ ```yaml
275
+ network:
276
+ type: "lora"
277
+ linear: 128
278
+ linear_alpha: 128
279
+ network_kwargs:
280
+ ignore_if_contains:
281
+ - "transformer.single_transformer_blocks."
282
+ ```
283
+
284
+ `ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
285
+ it will be ignored.
286
+
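To discover those layer names in the first place, listing the keys of a saved checkpoint is usually enough. A minimal sketch using the `safetensors` library; the file path is a placeholder.

```python
from safetensors import safe_open

# placeholder path: any diffusers-format transformer or LoRA .safetensors file
path = "/path/to/model_or_lora.safetensors"

with safe_open(path, framework="pt", device="cpu") as f:
    keys = list(f.keys())  # reads only the header, no tensors are loaded

# print only the weights targeted by the only_if_contains / ignore_if_contains
# filters shown above
for key in keys:
    if "single_transformer_blocks" in key:
        print(key)
```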
287
+ ---
288
+
289
+ ## EVERYTHING BELOW THIS LINE IS OUTDATED
290
+
291
+ It may still work like that, but I have not tested it in a while.
292
+
293
+ ---
294
+
295
+ ### Batch Image Generation
296
+
297
+ An image generator that can take prompts from a config file or from a txt file and generate them to a
298
+ folder. I mainly needed this for an SDXL test I am doing but added some polish to it so it can be used
299
+ for general batch image generation.
300
+ It all runs off a config file, which you can find an example of in `config/examples/generate.example.yaml`.
301
+ More info is in the comments in the example.
302
+
303
+ ---
304
+
305
+ ### LoRA (lierla), LoCON (LyCORIS) extractor
306
+
307
+ It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adds some QOL features
308
+ and LoRA (lierla) support. It can do multiple types of extractions in one run.
309
+ It all runs off a config file, which you can find an example of in `config/examples/extract.example.yml`.
310
+ Just copy that file into the `config` folder, and rename it to `whatever_you_want.yml`.
311
+ Then you can edit the file to your liking and call it like so:
312
+
313
+ ```bash
314
+ python3 run.py config/whatever_you_want.yml
315
+ ```
316
+
317
+ You can also put a full path to a config file, if you want to keep it somewhere else.
318
+
319
+ ```bash
320
+ python3 run.py "/home/user/whatever_you_want.yml"
321
+ ```
322
+
323
+ More notes on how it works are available in the example config file itself. LoRA and LoCON both support
324
+ extractions of 'fixed', 'threshold', 'ratio', 'quantile'. I'll update what these do and mean later.
325
+ Most people use fixed, which is traditional fixed-dimension extraction.
326
+
327
+ `process` is an array of different processes to run. You can add a few and mix and match. One LoRA, one LoCON, etc.
328
+
329
+ ---
330
+
331
+ ### LoRA Rescale
332
+
333
+ Change `<lora:my_lora:4.6>` to `<lora:my_lora:1.0>` or whatever you want with the same effect.
334
+ A tool for rescaling a LoRA's weights. It should work with LoCON as well, but I have not tested it.
335
+ It all runs off a config file, which you can find an example of in `config/examples/mod_lora_scale.yaml`.
336
+ Just copy that file into the `config` folder, and rename it to `whatever_you_want.yml`.
337
+ Then you can edit the file to your liking and call it like so:
338
+
339
+ ```bash
340
+ python3 run.py config/whatever_you_want.yml
341
+ ```
342
+
343
+ You can also put a full path to a config file, if you want to keep it somewhere else.
344
+
345
+ ```bash
346
+ python3 run.py "/home/user/whatever_you_want.yml"
347
+ ```
348
+
349
+ More notes on how it works are available in the example config file itself. This is useful when making
350
+ all LoRAs, as the ideal weight is rarely 1.0, but now you can fix that. For sliders, they can have weird scales form -2 to 2
351
+ or even -15 to 15. This will allow you to dile it in so they all have your desired scale
352
+
353
+ ---
354
+
355
+ ### LoRA Slider Trainer
356
+
357
+ <a target="_blank" href="https://colab.research.google.com/github/ostris/ai-toolkit/blob/main/notebooks/SliderTraining.ipynb">
358
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
359
+ </a>
360
+
361
+ This is how I train most of the recent sliders I have on Civitai, you can check them out in my [Civitai profile](https://civitai.com/user/Ostris/models).
362
+ It is based off the work by [p1atdev/LECO](https://github.com/p1atdev/LECO) and [rohitgandikota/erasing](https://github.com/rohitgandikota/erasing)
363
+ but has been heavily modified to create sliders rather than erasing concepts. I have a lot more plans for this, but it is
364
+ very functional as is. It is also very easy to use. Just copy the example config file in `config/examples/train_slider.example.yml`
365
+ to the `config` folder and rename it to `whatever_you_want.yml`. Then you can edit the file to your liking and call it like so:
366
+
367
+ ```bash
368
+ python3 run.py config/whatever_you_want.yml
369
+ ```
370
+
371
+ There is a lot more information in that example file. You can even run the example as is without any modifications to see
372
+ how it works. It will create a slider that turns all animals into dogs (neg) or cats (pos). Just run it like so:
373
+
374
+ ```bash
375
+ python3 run.py config/examples/train_slider.example.yml
376
+ ```
377
+
378
+ And you will be able to see how it works without configuring anything. No datasets are required for this method.
379
+ I will post a better tutorial soon.
380
+
381
+ ---
382
+
383
+ ## Extensions!!
384
+
385
+ You can now make and share custom extensions that run within this framework and have all the built-in tools
386
+ available to them. I will probably use this as the primary development method going
387
+ forward so I don't keep adding more and more features to this base repo. I will likely migrate a lot
388
+ of the existing functionality as well to make everything modular. There is an example extension in the `extensions`
389
+ folder that shows how to make a model merger extension. All of the code is heavily documented which is hopefully
390
+ enough to get you started. To make an extension, just copy that example and replace all the things you need to.
391
+
392
+
393
+ ### Model Merger - Example Extension
394
+ It is located in the `extensions` folder. It is a fully functional model merger that can merge as many models together
395
+ as you want. It is a good example of how to make an extension, but it is also a pretty useful feature since most
396
+ mergers can only do one model at a time and this one will take as many as you want to feed it. There is an
397
+ example config file in there, just copy that to your `config` folder and rename it to `whatever_you_want.yml`.
398
+ and use it like any other config file.
399
+
400
+ ## WIP Tools
401
+
402
+
403
+ ### VAE (Variational Auto Encoder) Trainer
404
+
405
+ This works, but is not ready for others to use and therefore does not have an example config.
406
+ I am still working on it. I will update this when it is ready.
407
+ I am adding a lot of features for criteria that I have used in my image enlargement work. A Critic (discriminator),
408
+ content loss, style loss, and a few more. If you don't know, the VAE
409
+ for stable diffusion (yes, even the MSE one, and SDXL) is horrible at smaller faces, and it holds SD back. I will fix this.
410
+ I'll post more about this with better examples later, but here is a quick test of a run-through with various VAEs.
411
+ Just went in and out. It is much worse on smaller faces than shown here.
412
+
413
+ <img src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/VAE_test1.jpg" width="768" height="auto">
414
+
415
+ ---
416
+
417
+ ## TODO
418
+ - [X] Add proper regs on sliders
419
+ - [X] Add SDXL support (base model only for now)
420
+ - [ ] Add plain erasing
421
+ - [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
422
+
423
+ ---
424
+
425
+ ## Change Log
426
+
427
+ #### 2023-08-05
428
+ - Huge memory rework and slider rework. Slider training is better than ever with no more
429
+ RAM spikes. I also made it so all 4 parts of the slider algorithm run in one batch so they share gradient
430
+ accumulation. This makes it much faster and more stable.
431
+ - Updated the example config to be something more practical and more updated to current methods. It is now
432
+ a detail slider and shows how to train one without a subject. 512x512 slider training for 1.5 should work on
433
+ a 6GB GPU now. Will test soon to verify.
434
+
435
+
436
+ #### 2021-10-20
437
+ - Windows support bug fixes
438
+ - Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
439
+ check out the example in the `extensions` folder. Read more about that above.
440
+ - Model Merging, provided via the example extension.
441
+
442
+ #### 2023-08-03
443
+ Another big refactor to make SD more modular.
444
+
445
+ Made batch image generation script
446
+
447
+ #### 2023-08-01
448
+ Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
449
+ Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
450
+ at the moment, so hopefully there are not breaking changes.
451
+
452
+ Unfortunately, I am too lazy to write a proper changelog with all the changes.
453
+
454
+ I added SDXL training to sliders... but.. it does not work properly.
455
+ The slider training relies on a model's ability to understand that an unconditional (negative prompt)
456
+ means you do not want that concept in the output. SDXL does not understand this for whatever reason,
457
+ which makes separating out
458
+ concepts within the model hard. I am sure the community will find a way to fix this
459
+ over time, but for now, it is not
460
+ going to work properly. And if any of you are thinking "Could we maybe fix it by adding 1 or 2 more text
461
+ encoders to the model as well as a few more entirely separate diffusion networks?" No. God no. It just needs a little
462
+ training without every experimental new paper added to it. The KISS principle.
463
+
464
+
465
+ #### 2023-07-30
466
+ Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
467
+ regularizer. You can set the network multiplier to force spread consistency at high weights
468
+
app.py ADDED
@@ -0,0 +1,6 @@
1
+ import os
2
+ import sys
3
+ from hf_ui import demo
4
+
5
+ if __name__ == "__main__":
6
+ demo.launch()
assets/glif.svg ADDED
assets/lora_ease_ui.png ADDED
build_and_push_docker.yaml ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env bash
2
+
3
+ echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
4
+ # wait 2 seconds
5
+ sleep 2
6
+ docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:latest -f docker/Dockerfile .
7
+ docker tag aitoolkit:latest ostris/aitoolkit:latest
8
+ docker push ostris/aitoolkit:latest
config/examples/extract.example.yml ADDED
@@ -0,0 +1,75 @@
1
+ ---
2
+ # this is in yaml format. You can use json if you prefer
3
+ # I like both but yaml is easier to read and write
4
+ # plus it has comments which is nice for documentation
5
+ job: extract # tells the runner what to do
6
+ config:
7
+ # the name will be used to create a folder in the output folder
8
+ # it will also replace any [name] token in the rest of this config
9
+ name: name_of_your_model
10
+ # can be hugging face model, a .ckpt, or a .safetensors
11
+ base_model: "/path/to/base/model.safetensors"
12
+ # can be hugging face model, a .ckpt, or a .safetensors
13
+ extract_model: "/path/to/model/to/extract/trained.safetensors"
14
+ # we will create a folder here with the name above. This will create /path/to/output/folder/name_of_your_model
15
+ output_folder: "/path/to/output/folder"
16
+ is_v2: false
17
+ dtype: fp16 # saved dtype
18
+ device: cpu # cpu, cuda:0, etc
19
+
20
+ # processes can be chained like this to run multiple in a row
21
+ # they must all use same models above, but great for testing different
22
+ # sizes and types of extractions. It is much faster as we already have the models loaded
23
+ process:
24
+ # process 1
25
+ - type: locon # locon or lora (locon is lycoris)
26
+ filename: "[name]_64_32.safetensors" # will be put in output folder
27
+ dtype: fp16
28
+ mode: fixed
29
+ linear: 64
30
+ conv: 32
31
+
32
+ # process 2
33
+ - type: locon
34
+ output_path: "/absolute/path/for/this/output.safetensors" # can be absolute
35
+ mode: ratio
36
+ linear: 0.2
37
+ conv: 0.2
38
+
39
+ # process 3
40
+ - type: locon
41
+ filename: "[name]_ratio_02.safetensors"
42
+ mode: quantile
43
+ linear: 0.5
44
+ conv: 0.5
45
+
46
+ # process 4
47
+ - type: lora # traditional lora extraction (lierla) with linear layers only
48
+ filename: "[name]_4.safetensors"
49
+ mode: fixed # fixed, ratio, quantile supported for lora as well
50
+ linear: 4 # lora dim or rank
51
+ # no conv for lora
52
+
53
+ # process 5
54
+ - type: lora
55
+ filename: "[name]_q05.safetensors"
56
+ mode: quantile
57
+ linear: 0.5
58
+
59
+ # you can put any information you want here, and it will be saved in the model
60
+ # the below is an example. I recommend doing trigger words at a minimum
61
+ # in the metadata. The software will include this plus some other information
62
+ meta:
63
+ name: "[name]" # [name] gets replaced with the name above
64
+ description: A short description of your model
65
+ trigger_words:
66
+ - put
67
+ - trigger
68
+ - words
69
+ - here
70
+ version: '0.1'
71
+ creator:
72
+ name: Your Name
73
74
+ website: https://yourwebsite.com
75
+ any: All meta data above is arbitrary, it can be whatever you want.
config/examples/generate.example.yaml ADDED
@@ -0,0 +1,60 @@
1
+ ---
2
+
3
+ job: generate # tells the runner what to do
4
+ config:
5
+ name: "generate" # this is not really used anywhere currently but required by runner
6
+ process:
7
+ # process 1
8
+ - type: to_folder # process images to a folder
9
+ output_folder: "output/gen"
10
+ device: cuda:0 # cpu, cuda:0, etc
11
+ generate:
12
+ # these are your defaults you can override most of them with flags
13
+ sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now
14
+ width: 1024
15
+ height: 1024
16
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
17
+ seed: -1 # -1 is random
18
+ guidance_scale: 7
19
+ sample_steps: 20
20
+ ext: ".png" # .png, .jpg, .jpeg, .webp
21
+
22
+ # here are the flags you can use for prompts. Always start with
23
+ # your prompt first then add these flags after. You can use as many
24
+ # like
25
+ # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
26
+ # we will try to support all sd-scripts flags where we can
27
+
28
+ # FROM SD-SCRIPTS
29
+ # --n Treat everything until the next option as a negative prompt.
30
+ # --w Specify the width of the generated image.
31
+ # --h Specify the height of the generated image.
32
+ # --d Specify the seed for the generated image.
33
+ # --l Specify the CFG scale for the generated image.
34
+ # --s Specify the number of steps during generation.
35
+
36
+ # OURS and some QOL additions
37
+ # --p2 Prompt for the second text encoder (SDXL only)
38
+ # --n2 Negative prompt for the second text encoder (SDXL only)
39
+ # --gr Specify the guidance rescale for the generated image (SDXL only)
40
+ # --seed Specify the seed for the generated image same as --d
41
+ # --cfg Specify the CFG scale for the generated image same as --l
42
+ # --steps Specify the number of steps during generation same as --s
43
+
44
+ prompt_file: false # if true a txt file will be created next to images with prompt strings used
45
+ # prompts can also be a path to a text file with one prompt per line
46
+ # prompts: "/path/to/prompts.txt"
47
+ prompts:
48
+ - "photo of batman"
49
+ - "photo of superman"
50
+ - "photo of spiderman"
51
+ - "photo of a superhero --n batman superman spiderman"
52
+
53
+ model:
54
+ # huggingface name, path relative to the project, or absolute path to .safetensors or .ckpt
55
+ # name_or_path: "runwayml/stable-diffusion-v1-5"
56
+ name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
57
+ is_v2: false # for v2 models
58
+ is_v_pred: false # for v-prediction models (most v2 models)
59
+ is_xl: false # for SDXL models
60
+ dtype: bf16
config/examples/mod_lora_scale.yaml ADDED
@@ -0,0 +1,48 @@
1
+ ---
2
+ job: mod
3
+ config:
4
+ name: name_of_your_model_v1
5
+ process:
6
+ - type: rescale_lora
7
+ # path to your current lora model
8
+ input_path: "/path/to/lora/lora.safetensors"
9
+ # output path for your new lora model, can be the same as input_path to replace
10
+ output_path: "/path/to/lora/output_lora_v1.safetensors"
11
+ # replaces meta with the meta below (plus minimum meta fields)
12
+ # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
13
+ replace_meta: true
14
+ # how to adjust, we can scale the up_down weights or the alpha
15
+ # up_down is the default and probably the best, they will both net the same outputs
16
+ # would only affect rare NaN cases and maybe merging with old merge tools
17
+ scale_target: 'up_down'
18
+ # precision to save, fp16 is the default and standard
19
+ save_dtype: fp16
20
+ # current_weight is the ideal weight you use as a multiplier when using the lora
21
+ # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
22
+ # you can do negatives here too if you want to flip the lora
23
+ current_weight: 6.0
24
+ # target_weight is the ideal weight you use as a multiplier when using the lora
25
+ # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
26
+ # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
27
+ target_weight: 1.0
28
+
29
+ # base model for the lora
30
+ # this is just used to add meta so automatic111 knows which model it is for
31
+ # assume v1.5 if these are not set
32
+ is_xl: false
33
+ is_v2: false
34
+ meta:
35
+ # this is only used if you set replace_meta to true above
36
+ name: "[name]" # [name] gets replaced with the name above
37
+ description: A short description of your lora
38
+ trigger_words:
39
+ - put
40
+ - trigger
41
+ - words
42
+ - here
43
+ version: '0.1'
44
+ creator:
45
+ name: Your Name
46
47
+ website: https://yourwebsite.com
48
+ any: All meta data above is arbitrary, it can be whatever you want.
config/examples/modal/modal_train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need this on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and
65
+ # place it like "/root/ai-toolkit/FLUX.1-dev"
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need the on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and
65
+ # place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter"
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
config/examples/train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+        gradient_checkpointing: true  # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+        # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
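[Note] Per the repo README, a config like the one above is launched with the toolkit's run.py entry point (python run.py config/examples/train_lora_flux_24gb.yaml). The snippet below is a minimal, hypothetical way to smoke-test the finished LoRA with diffusers' FluxPipeline; the checkpoint path is an assumption based on the training_folder and name fields above, so adjust it to your actual output.

import torch
from diffusers import FluxPipeline

# Base model must match model.name_or_path from the training config.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Hypothetical path: ai-toolkit writes checkpoints under <training_folder>/<name>.
pipe.load_lora_weights(
    "output/my_first_flux_lora_v1",
    weight_name="my_first_flux_lora_v1.safetensors",
)

# Mirror the sample block above: 1024x1024, guidance_scale 4, 20 steps, seed 42.
image = pipe(
    "woman with red hair, playing chess at the park",
    width=1024,
    height=1024,
    guidance_scale=4.0,
    num_inference_steps=20,
    generator=torch.manual_seed(42),
).images[0]
image.save("flux_lora_sample.png")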
config/examples/train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+        gradient_checkpointing: true  # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+          # low_vram is painfully slow to fuse in the adapter; avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
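[Note] The schnell config above is the same recipe with three notable differences: the schnell base model, the required assistant training adapter, and distilled-style sampling (guidance 1, 1-4 steps). A minimal, hypothetical inference sketch with only those pieces changed from the dev example (paths are illustrative):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
# Hypothetical checkpoint path based on training_folder/name above.
pipe.load_lora_weights(
    "output/my_first_flux_lora_v1",
    weight_name="my_first_flux_lora_v1.safetensors",
)

# schnell is a distilled model: guidance is effectively unused and 4 steps is enough.
image = pipe(
    "a woman holding a coffee cup, in a beanie, sitting at a cafe",
    num_inference_steps=4,
    guidance_scale=0.0,
    generator=torch.manual_seed(42),
).images[0]
image.save("flux_schnell_lora_sample.png")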
config/examples/train_lora_sd35_large_24gb.yaml ADDED
@@ -0,0 +1,97 @@
1
+ ---
2
+ # NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_sd3l_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: float16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
25
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
26
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
27
+ # hf_repo_id: your-username/your-model-slug
28
+ # hf_private: true #whether the repo is private or public
29
+ datasets:
30
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
31
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
32
+ # images will automatically be resized and bucketed into the resolution specified
33
+ # on windows, escape back slashes with another backslash so
34
+ # "C:\\path\\to\\images\\folder"
35
+ - folder_path: "/path/to/images/folder"
36
+ caption_ext: "txt"
37
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
38
+ shuffle_tokens: false # shuffle caption order, split by commas
39
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
40
+ resolution: [ 1024 ]
41
+ train:
42
+ batch_size: 1
43
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
44
+ gradient_accumulation_steps: 1
45
+ train_unet: true
46
+ train_text_encoder: false # May not fully work with SD3 yet
47
+        gradient_checkpointing: true  # need this on unless you have a ton of vram
48
+ noise_scheduler: "flowmatch"
49
+ timestep_type: "linear" # linear or sigmoid
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+        # uncomment to use new bell curved weighting. Experimental but may produce better results
57
+ # linear_timesteps: true
58
+
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+
64
+ # will probably need this if gpu supports it for sd3, other dtypes may not work correctly
65
+ dtype: bf16
66
+ model:
67
+ # huggingface model name or path
68
+ name_or_path: "stabilityai/stable-diffusion-3.5-large"
69
+ is_v3: true
70
+ quantize: true # run 8bit mixed precision
71
+ sample:
72
+ sampler: "flowmatch" # must match train.noise_scheduler
73
+ sample_every: 250 # sample every this many steps
74
+ width: 1024
75
+ height: 1024
76
+ prompts:
77
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
78
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
79
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
80
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
81
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
82
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
83
+ - "a bear building a log cabin in the snow covered mountains"
84
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
85
+ - "hipster man with a beard, building a chair, in a wood shop"
86
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
87
+ - "a man holding a sign that says, 'this is a sign'"
88
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
89
+ neg: ""
90
+ seed: 42
91
+ walk_seed: true
92
+ guidance_scale: 4
93
+ sample_steps: 25
94
+ # you can add any additional meta info here. [name] is replaced with config name at top
95
+ meta:
96
+ name: "[name]"
97
+ version: '1.0'
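[Note] As with the FLUX examples, the SD3.5 LoRA can be smoke-tested in diffusers after training. A minimal sketch, assuming diffusers' StableDiffusion3Pipeline and the default output layout implied by training_folder and name above (paths are illustrative):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")
# Hypothetical checkpoint path.
pipe.load_lora_weights(
    "output/my_first_sd3l_lora_v1",
    weight_name="my_first_sd3l_lora_v1.safetensors",
)

# Match the sample block above: 25 steps, guidance_scale 4.
image = pipe(
    "a bear building a log cabin in the snow covered mountains",
    num_inference_steps=25,
    guidance_scale=4.0,
    generator=torch.manual_seed(42),
).images[0]
image.save("sd35_lora_sample.png")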
config/examples/train_slider.example.yml ADDED
@@ -0,0 +1,230 @@
1
+ ---
2
+ # This is in yaml format. You can use json if you prefer
3
+ # I like both but yaml is easier to write
4
+ # Plus it has comments which is nice for documentation
5
+ # This is the config I use on my sliders, It is solid and tested
6
+ job: train
7
+ config:
8
+ # the name will be used to create a folder in the output folder
9
+ # it will also replace any [name] token in the rest of this config
10
+ name: detail_slider_v1
11
+ # folder will be created with name above in folder below
12
+ # it can be relative to the project root or absolute
13
+ training_folder: "output/LoRA"
14
+ device: cuda:0 # cpu, cuda:0, etc
15
+ # for tensorboard logging, we will make a subfolder for this job
16
+ log_dir: "output/.tensorboard"
17
+  # you can stack processes for other jobs, but it is not tested with sliders
18
+ # just use one for now
19
+ process:
20
+ - type: slider # tells runner to run the slider process
21
+ # network is the LoRA network for a slider, I recommend to leave this be
22
+ network:
23
+ # network type lierla is traditional LoRA that works everywhere, only linear layers
24
+ type: "lierla"
25
+ # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
26
+ linear: 8
27
+ linear_alpha: 4 # Do about half of rank
28
+ # training config
29
+ train:
30
+ # this is also used in sampling. Stick with ddpm unless you know what you are doing
31
+ noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
32
+ # how many steps to train. More is not always better. I rarely go over 1000
33
+ steps: 500
34
+ # I have had good results with 4e-4 to 1e-4 at 500 steps
35
+ lr: 2e-4
36
+ # enables gradient checkpoint, saves vram, leave it on
37
+ gradient_checkpointing: true
38
+ # train the unet. I recommend leaving this true
39
+ train_unet: true
40
+ # train the text encoder. I don't recommend this unless you have a special use case
41
+ # for sliders we are adjusting representation of the concept (unet),
42
+ # not the description of it (text encoder)
43
+ train_text_encoder: false
44
+ # same as from sd-scripts, not fully tested but should speed up training
45
+ min_snr_gamma: 5.0
46
+ # just leave unless you know what you are doing
47
+ # also supports "dadaptation" but set lr to 1 if you use that,
48
+ # but it learns too fast and I don't recommend it
49
+ optimizer: "adamw"
50
+ # only constant for now
51
+ lr_scheduler: "constant"
52
+      # we randomly denoise a random number of steps from 1 to this number
53
+ # while training. Just leave it
54
+ max_denoising_steps: 40
55
+ # works great at 1. I do 1 even with my 4090.
56
+ # higher may not work right with newer single batch stacking code anyway
57
+ batch_size: 1
58
+ # bf16 works best if your GPU supports it (modern)
59
+ dtype: bf16 # fp32, bf16, fp16
60
+ # if you have it, use it. It is faster and better
61
+      # torch 2.0 doesn't need xformers anymore; only use it if you have a lower version
62
+ # xformers: true
63
+ # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
64
+ # although, the way we train sliders is comparative, so it probably won't work anyway
65
+ noise_offset: 0.0
66
+ # noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL
67
+
68
+ # the model to train the LoRA network on
69
+ model:
70
+      # huggingface name, a path relative to the project root, or an absolute path to .safetensors or .ckpt
71
+ name_or_path: "runwayml/stable-diffusion-v1-5"
72
+ is_v2: false # for v2 models
73
+ is_v_pred: false # for v-prediction models (most v2 models)
74
+ # has some issues with the dual text encoder and the way we train sliders
75
+      # it works, but weights probably need to be higher to see the effect.
76
+ is_xl: false # for SDXL models
77
+
78
+ # saving config
79
+ save:
80
+ dtype: float16 # precision to save. I recommend float16
81
+ save_every: 50 # save every this many steps
82
+ # this will remove step counts more than this number
83
+ # allows you to save more often in case of a crash without filling up your drive
84
+ max_step_saves_to_keep: 2
85
+
86
+ # sampling config
87
+ sample:
88
+ # must match train.noise_scheduler, this is not used here
89
+ # but may be in future and in other processes
90
+ sampler: "ddpm"
91
+ # sample every this many steps
92
+ sample_every: 20
93
+ # image size
94
+ width: 512
95
+ height: 512
96
+ # prompts to use for sampling. Do as many as you want, but it slows down training
97
+ # pick ones that will best represent the concept you are trying to adjust
98
+ # allows some flags after the prompt
99
+ # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
100
+ # slide are good tests. will inherit sample.network_multiplier if not set
101
+ # --n [string] # negative prompt, will inherit sample.neg if not set
102
+ # Only 75 tokens allowed currently
103
+ # I like to do a wide positive and negative spread so I can see a good range and stop
104
+      #  early if the network is breaking down
105
+ prompts:
106
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
107
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
108
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
109
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
110
+ - "a golden retriever sitting on a leather couch, --m -5"
111
+ - "a golden retriever sitting on a leather couch --m -3"
112
+ - "a golden retriever sitting on a leather couch --m 3"
113
+ - "a golden retriever sitting on a leather couch --m 5"
114
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
115
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
116
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
117
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
118
+ # negative prompt used on all prompts above as default if they don't have one
119
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
120
+ # seed for sampling. 42 is the answer for everything
121
+ seed: 42
122
+ # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
123
+ # will start over on next sample_every so s1 is always seed
124
+ # works well if you use same prompt but want different results
125
+ walk_seed: false
126
+ # cfg scale (4 to 10 is good)
127
+ guidance_scale: 7
128
+ # sampler steps (20 to 30 is good)
129
+ sample_steps: 20
130
+ # default network multiplier for all prompts
131
+ # since we are training a slider, I recommend overriding this with --m [number]
132
+ # in the prompts above to get both sides of the slider
133
+ network_multiplier: 1.0
134
+
135
+ # logging information
136
+ logging:
137
+ log_every: 10 # log every this many steps
138
+ use_wandb: false # not supported yet
139
+        verbose: false # probably don't need unless you are debugging
140
+
141
+ # slider training config, best for last
142
+ slider:
143
+ # resolutions to train on. [ width, height ]. This is less important for sliders
144
+ # as we are not teaching the model anything it doesn't already know
145
+ # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
146
+ # and [ 1024, 1024 ] for sd_xl
147
+ # you can do as many as you want here
148
+ resolutions:
149
+ - [ 512, 512 ]
150
+ # - [ 512, 768 ]
151
+ # - [ 768, 768 ]
152
+ # slider training uses 4 combined steps for a single round. This will do it in one gradient
153
+      #  step. It is highly optimized and shouldn't take any more vram than doing it without,
154
+ # since we break down batches for gradient accumulation now. so just leave it on.
155
+ batch_full_slide: true
156
+ # These are the concepts to train on. You can do as many as you want here,
157
+      #  but they can conflict and outweigh each other. Other than experimenting, I recommend
158
+ # just doing one for good results
159
+ targets:
160
+ # target_class is the base concept we are adjusting the representation of
161
+ # for example, if we are adjusting the representation of a person, we would use "person"
162
+ # if we are adjusting the representation of a cat, we would use "cat" It is not
163
+ # a keyword necessarily but what the model understands the concept to represent.
164
+ # "person" will affect men, women, children, etc but will not affect cats, dogs, etc
165
+ # it is the models base general understanding of the concept and everything it represents
166
+ # you can leave it blank to affect everything. In this example, we are adjusting
167
+ # detail, so we will leave it blank to affect everything
168
+ - target_class: ""
169
+ # positive is the prompt for the positive side of the slider.
170
+ # It is the concept that will be excited and amplified in the model when we slide the slider
171
+ # to the positive side and forgotten / inverted when we slide
172
+ # the slider to the negative side. It is generally best to include the target_class in
173
+ # the prompt. You want it to be the extreme of what you want to train on. For example,
174
+ # if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
175
+ # as the prompt. Not just "fat person"
176
+ # max 75 tokens for now
177
+ positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
178
+ # negative is the prompt for the negative side of the slider and works the same as positive
179
+ # it does not necessarily work the same as a negative prompt when generating images
180
+ # these need to be polar opposites.
181
+ # max 76 tokens for now
182
+ negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
183
+ # the loss for this target is multiplied by this number.
184
+ # if you are doing more than one target it may be good to set less important ones
185
+ # to a lower number like 0.1 so they don't outweigh the primary target
186
+ weight: 1.0
187
+ # shuffle the prompts split by the comma. We will run every combination randomly
188
+ # this will make the LoRA more robust. You probably want this on unless prompt order
189
+ # is important for some reason
190
+ shuffle: true
191
+
192
+
193
+ # anchors are prompts that we will try to hold on to while training the slider
194
+ # these are NOT necessary and can prevent the slider from converging if not done right
195
+ # leave them off if you are having issues, but they can help lock the network
196
+ # on certain concepts to help prevent catastrophic forgetting
197
+ # you want these to generate an image that is not your target_class, but close to it
198
+ # is fine as long as it does not directly overlap it.
199
+ # For example, if you are training on a person smiling,
200
+ # you could use "a person with a face mask" as an anchor. It is a person, the image is the same
201
+ # regardless if they are smiling or not, however, the closer the concept is to the target_class
202
+ # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
203
+ # for close concepts, you want to be closer to 0.1 or 0.2
204
+ # these will slow down training. I am leaving them off for the demo
205
+
206
+ # anchors:
207
+ # - prompt: "a woman"
208
+ # neg_prompt: "animal"
209
+ # # the multiplier applied to the LoRA when this is run.
210
+ # # higher will give it more weight but also help keep the lora from collapsing
211
+ # multiplier: 1.0
212
+ # - prompt: "a man"
213
+ # neg_prompt: "animal"
214
+ # multiplier: 1.0
215
+ # - prompt: "a person"
216
+ # neg_prompt: "animal"
217
+ # multiplier: 1.0
218
+
219
+ # You can put any information you want here, and it will be saved in the model.
220
+ # The below is an example, but you can put your grocery list in it if you want.
221
+ # It is saved in the model so be aware of that. The software will include this
222
+ # plus some other information for you automatically
223
+ meta:
224
+ # [name] gets replaced with the name above
225
+ name: "[name]"
226
+ # version: '1.0'
227
+ # creator:
228
+ # name: Your Name
229
+ # email: [email protected]
230
+ # website: https://your.website
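[Note] A trained slider like this is used by sweeping its strength at inference, the same way the --m flags are used in the sample prompts above. The sketch below is a hypothetical diffusers-based test, assuming the saved slider loads with load_lora_weights; the cross_attention_kwargs scale stands in for the network multiplier, and the checkpoint path is illustrative.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Hypothetical path: checkpoints land under training_folder/name from the config above.
pipe.load_lora_weights("output/LoRA/detail_slider_v1", weight_name="detail_slider_v1.safetensors")

prompt = "a woman in a coffee shop, black hat, blonde hair, blue jacket"
for multiplier in (-5, -3, 3, 5):
    # scale plays the role of --m: negative de-emphasizes the concept, positive amplifies it.
    image = pipe(
        prompt,
        negative_prompt="cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome",
        guidance_scale=7,
        num_inference_steps=20,
        generator=torch.manual_seed(42),
        cross_attention_kwargs={"scale": multiplier},
    ).images[0]
    image.save(f"detail_slider_{multiplier}.png")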
docker/Dockerfile ADDED
@@ -0,0 +1,31 @@
1
+ FROM runpod/base:0.6.2-cuda12.2.0
2
+
3
+ LABEL authors="jaret"
4
+
5
+ # Install dependencies
6
+ RUN apt-get update
7
+
8
+ WORKDIR /app
9
+ ARG CACHEBUST=1
10
+ RUN git clone https://github.com/ostris/ai-toolkit.git && \
11
+ cd ai-toolkit && \
12
+ git submodule update --init --recursive
13
+
14
+ WORKDIR /app/ai-toolkit
15
+
16
+ RUN ln -s /usr/bin/python3 /usr/bin/python
17
+ RUN python -m pip install -r requirements.txt
18
+
19
+ RUN apt-get install -y tmux nvtop htop
20
+
21
+ RUN pip install jupyterlab
22
+
23
+ # make workspace
24
+ RUN mkdir /workspace
25
+
26
+
27
+ # symlink app to workspace
28
+ RUN ln -s /app/ai-toolkit /workspace/ai-toolkit
29
+
30
+ WORKDIR /
31
+ CMD ["/start.sh"]
extensions/example/ExampleMergeModels.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import gc
3
+ from collections import OrderedDict
4
+ from typing import TYPE_CHECKING
5
+ from jobs.process import BaseExtensionProcess
6
+ from toolkit.config_modules import ModelConfig
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ from toolkit.train_tools import get_torch_dtype
9
+ from tqdm import tqdm
10
+
11
+ # Type check imports. Prevents circular imports
12
+ if TYPE_CHECKING:
13
+ from jobs import ExtensionJob
14
+
15
+
16
+ # extend standard config classes to add weight
17
+ class ModelInputConfig(ModelConfig):
18
+ def __init__(self, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.weight = kwargs.get('weight', 1.0)
21
+ # overwrite default dtype unless user specifies otherwise
22
+        # float32 will give better precision in the merging functions
23
+ self.dtype: str = kwargs.get('dtype', 'float32')
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ # this is our main class process
32
+ class ExampleMergeModels(BaseExtensionProcess):
33
+ def __init__(
34
+ self,
35
+ process_id: int,
36
+ job: 'ExtensionJob',
37
+ config: OrderedDict
38
+ ):
39
+ super().__init__(process_id, job, config)
40
+ # this is the setup process, do not do process intensive stuff here, just variable setup and
41
+ # checking requirements. This is called before the run() function
42
+ # no loading models or anything like that, it is just for setting up the process
43
+ # all of your process intensive stuff should be done in the run() function
44
+ # config will have everything from the process item in the config file
45
+
46
+        # convenience methods exist on BaseProcess to get config values
47
+ # if required is set to true and the value is not found it will throw an error
48
+ # you can pass a default value to get_conf() as well if it was not in the config file
49
+ # as well as a type to cast the value to
50
+ self.save_path = self.get_conf('save_path', required=True)
51
+ self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
52
+ self.device = self.get_conf('device', default='cpu', as_type=torch.device)
53
+
54
+ # build models to merge list
55
+ models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
56
+ # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
57
+ # this way you can add methods to it and it is easier to read and code. There are a lot of
58
+ # inbuilt config classes located in toolkit.config_modules as well
59
+ self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
60
+ # setup is complete. Don't load anything else here, just setup variables and stuff
61
+
62
+    # this is the entire run process; be sure to call super().run() first
63
+ def run(self):
64
+ # always call first
65
+ super().run()
66
+ print(f"Running process: {self.__class__.__name__}")
67
+
68
+ # let's adjust our weights first to normalize them so the total is 1.0
69
+ total_weight = sum([model.weight for model in self.models_to_merge])
70
+ weight_adjust = 1.0 / total_weight
71
+ for model in self.models_to_merge:
72
+ model.weight *= weight_adjust
73
+
74
+ output_model: StableDiffusion = None
75
+ # let's do the merge, it is a good idea to use tqdm to show progress
76
+ for model_config in tqdm(self.models_to_merge, desc="Merging models"):
77
+ # setup model class with our helper class
78
+ sd_model = StableDiffusion(
79
+ device=self.device,
80
+ model_config=model_config,
81
+ dtype="float32"
82
+ )
83
+ # load the model
84
+ sd_model.load_model()
85
+
86
+ # adjust the weight of the text encoder
87
+ if isinstance(sd_model.text_encoder, list):
88
+ # sdxl model
89
+ for text_encoder in sd_model.text_encoder:
90
+ for key, value in text_encoder.state_dict().items():
91
+ value *= model_config.weight
92
+ else:
93
+ # normal model
94
+ for key, value in sd_model.text_encoder.state_dict().items():
95
+ value *= model_config.weight
96
+ # adjust the weights of the unet
97
+ for key, value in sd_model.unet.state_dict().items():
98
+ value *= model_config.weight
99
+
100
+ if output_model is None:
101
+ # use this one as the base
102
+ output_model = sd_model
103
+ else:
104
+ # merge the models
105
+ # text encoder
106
+ if isinstance(output_model.text_encoder, list):
107
+ # sdxl model
108
+ for i, text_encoder in enumerate(output_model.text_encoder):
109
+ for key, value in text_encoder.state_dict().items():
110
+ value += sd_model.text_encoder[i].state_dict()[key]
111
+ else:
112
+ # normal model
113
+ for key, value in output_model.text_encoder.state_dict().items():
114
+ value += sd_model.text_encoder.state_dict()[key]
115
+ # unet
116
+ for key, value in output_model.unet.state_dict().items():
117
+ value += sd_model.unet.state_dict()[key]
118
+
119
+ # remove the model to free memory
120
+ del sd_model
121
+ flush()
122
+
123
+ # merge loop is done, let's save the model
124
+ print(f"Saving merged model to {self.save_path}")
125
+ output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
126
+ print(f"Saved merged model to {self.save_path}")
127
+ # do cleanup here
128
+ del output_model
129
+ flush()
extensions/example/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # We make a subclass of Extension
6
+ class ExampleMergeExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "example_merge_extension"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Example Merge Extension"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ExampleMergeModels import ExampleMergeModels
19
+ return ExampleMergeModels
20
+
21
+
22
+ AI_TOOLKIT_EXTENSIONS = [
23
+ # you can put a list of extensions here
24
+ ExampleMergeExtension
25
+ ]
extensions/example/config/config.example.yaml ADDED
@@ -0,0 +1,48 @@
1
+ ---
2
+ # Always include at least one example config file to show how to use your extension.
3
+ # use plenty of comments so users know how to use it and what everything does
4
+
5
+ # all extensions will use this job name
6
+ job: extension
7
+ config:
8
+ name: 'my_awesome_merge'
9
+ process:
10
+ # Put your example processes here. This will be passed
11
+ # to your extension process in the config argument.
12
+ # the type MUST match your extension uid
13
+ - type: "example_merge_extension"
14
+ # save path for the merged model
15
+ save_path: "output/merge/[name].safetensors"
16
+ # save type
17
+ dtype: fp16
18
+ # device to run it on
19
+ device: cuda:0
20
+ # input models can only be SD1.x and SD2.x models for this example (currently)
21
+ models_to_merge:
22
+ # weights are relative, total weights will be normalized
23
+ # for example. If you have 2 models with weight 1.0, they will
24
+ # both be weighted 0.5. If you have 1 model with weight 1.0 and
25
+ # another with weight 2.0, the first will be weighted 1/3 and the
26
+ # second will be weighted 2/3
27
+ - name_or_path: "input/model1.safetensors"
28
+ weight: 1.0
29
+ - name_or_path: "input/model2.safetensors"
30
+ weight: 1.0
31
+ - name_or_path: "input/model3.safetensors"
32
+ weight: 0.3
33
+ - name_or_path: "input/model4.safetensors"
34
+ weight: 1.0
35
+
36
+
37
+ # you can put any information you want here, and it will be saved in the model
38
+ # the below is an example. I recommend doing trigger words at a minimum
39
+ # in the metadata. The software will include this plus some other information
40
+ meta:
41
+ name: "[name]" # [name] gets replaced with the name above
42
+ description: A short description of your model
43
+ version: '0.1'
44
+ creator:
45
+ name: Your Name
46
47
+ website: https://yourwebsite.com
48
+ any: All meta data above is arbitrary, it can be whatever you want.
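[Note] To make the relative-weight comment above concrete: ExampleMergeModels.run normalizes the listed weights so they sum to 1.0 before merging. For the four example models this works out as follows (a small standalone sketch of the same arithmetic):

weights = [1.0, 1.0, 0.3, 1.0]            # as listed under models_to_merge above
adjust = 1.0 / sum(weights)               # 1.0 / 3.3
normalized = [w * adjust for w in weights]
print([round(w, 3) for w in normalized])  # [0.303, 0.303, 0.091, 0.303], summing to ~1.0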
extensions_built_in/advanced_generator/Img2ImgGenerator.py ADDED
@@ -0,0 +1,256 @@
1
+ import math
2
+ import os
3
+ import random
4
+ from collections import OrderedDict
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from diffusers import T2IAdapter
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from torch.utils.data import DataLoader
12
+ from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
13
+ from tqdm import tqdm
14
+
15
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
16
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
17
+ from toolkit.sampler import get_sampler
18
+ from toolkit.stable_diffusion_model import StableDiffusion
19
+ import gc
20
+ import torch
21
+ from jobs.process import BaseExtensionProcess
22
+ from toolkit.data_loader import get_dataloader_from_datasets
23
+ from toolkit.train_tools import get_torch_dtype
24
+ from controlnet_aux.midas import MidasDetector
25
+ from diffusers.utils import load_image
26
+ from torchvision.transforms import ToTensor
27
+
28
+
29
+ def flush():
30
+ torch.cuda.empty_cache()
31
+ gc.collect()
32
+
33
+
34
+
35
+
36
+
37
+ class GenerateConfig:
38
+
39
+ def __init__(self, **kwargs):
40
+ self.prompts: List[str]
41
+ self.sampler = kwargs.get('sampler', 'ddpm')
42
+ self.neg = kwargs.get('neg', '')
43
+ self.seed = kwargs.get('seed', -1)
44
+ self.walk_seed = kwargs.get('walk_seed', False)
45
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
46
+ self.sample_steps = kwargs.get('sample_steps', 20)
47
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
48
+ self.ext = kwargs.get('ext', 'png')
49
+ self.denoise_strength = kwargs.get('denoise_strength', 0.5)
50
+ self.trigger_word = kwargs.get('trigger_word', None)
51
+
52
+
53
+ class Img2ImgGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
59
+ self.device = self.get_conf('device', 'cuda')
60
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
61
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
62
+ self.is_latents_cached = True
63
+ raw_datasets = self.get_conf('datasets', None)
64
+ if raw_datasets is not None and len(raw_datasets) > 0:
65
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
66
+ self.datasets = None
67
+ self.datasets_reg = None
68
+ self.dtype = self.get_conf('dtype', 'float16')
69
+ self.torch_dtype = get_torch_dtype(self.dtype)
70
+ self.params = []
71
+ if raw_datasets is not None and len(raw_datasets) > 0:
72
+ for raw_dataset in raw_datasets:
73
+ dataset = DatasetConfig(**raw_dataset)
74
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
75
+ if not is_caching:
76
+ self.is_latents_cached = False
77
+ if dataset.is_reg:
78
+ if self.datasets_reg is None:
79
+ self.datasets_reg = []
80
+ self.datasets_reg.append(dataset)
81
+ else:
82
+ if self.datasets is None:
83
+ self.datasets = []
84
+ self.datasets.append(dataset)
85
+
86
+ self.progress_bar = None
87
+ self.sd = StableDiffusion(
88
+ device=self.device,
89
+ model_config=self.model_config,
90
+ dtype=self.dtype,
91
+ )
92
+ print(f"Using device {self.device}")
93
+ self.data_loader: DataLoader = None
94
+ self.adapter: T2IAdapter = None
95
+
96
+ def to_pil(self, img):
97
+ # image comes in -1 to 1. convert to a PIL RGB image
98
+ img = (img + 1) / 2
99
+ img = img.clamp(0, 1)
100
+ img = img[0].permute(1, 2, 0).cpu().numpy()
101
+ img = (img * 255).astype(np.uint8)
102
+ image = Image.fromarray(img)
103
+ return image
104
+
105
+ def run(self):
106
+ with torch.no_grad():
107
+ super().run()
108
+ print("Loading model...")
109
+ self.sd.load_model()
110
+ device = torch.device(self.device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLImg2ImgPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ ).to(device, dtype=self.torch_dtype)
122
+ elif self.model_config.is_pixart:
123
+ pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
124
+ else:
125
+ raise NotImplementedError("Only XL models are supported")
126
+ pipe.set_progress_bar_config(disable=True)
127
+
128
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
129
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
130
+
131
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
132
+
133
+ num_batches = len(self.data_loader)
134
+ pbar = tqdm(total=num_batches, desc="Generating images")
135
+ seed = self.generate_config.seed
136
+ # load images from datasets, use tqdm
137
+ for i, batch in enumerate(self.data_loader):
138
+ batch: DataLoaderBatchDTO = batch
139
+
140
+ gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
141
+ generator = torch.manual_seed(gen_seed)
142
+
143
+ file_item: FileItemDTO = batch.file_items[0]
144
+ img_path = file_item.path
145
+ img_filename = os.path.basename(img_path)
146
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
147
+ img_filename = img_filename_no_ext + '.' + self.generate_config.ext
148
+ output_path = os.path.join(self.output_folder, img_filename)
149
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
150
+
151
+ if self.copy_inputs_to is not None:
152
+ output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
153
+ output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
154
+ else:
155
+ output_inputs_path = None
156
+ output_inputs_caption_path = None
157
+
158
+ caption = batch.get_caption_list()[0]
159
+ if self.generate_config.trigger_word is not None:
160
+ caption = caption.replace('[trigger]', self.generate_config.trigger_word)
161
+
162
+ img: torch.Tensor = batch.tensor.clone()
163
+ image = self.to_pil(img)
164
+
165
+ # image.save(output_depth_path)
166
+ if self.model_config.is_pixart:
167
+ pipe: PixArtSigmaPipeline = pipe
168
+
169
+ # Encode the full image once
170
+ encoded_image = pipe.vae.encode(
171
+ pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
172
+ if hasattr(encoded_image, "latent_dist"):
173
+ latents = encoded_image.latent_dist.sample(generator)
174
+ elif hasattr(encoded_image, "latents"):
175
+ latents = encoded_image.latents
176
+ else:
177
+ raise AttributeError("Could not access latents of provided encoder_output")
178
+ latents = pipe.vae.config.scaling_factor * latents
179
+
180
+ # latents = self.sd.encode_images(img)
181
+
182
+ # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
183
+ # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
184
+ # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
185
+ # timestep = timestep.to(device, dtype=torch.int32)
186
+ # latent = latent.to(device, dtype=self.torch_dtype)
187
+ # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
188
+ # latent = self.sd.add_noise(latent, noise, timestep)
189
+ # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
190
+ batch_size = 1
191
+ num_images_per_prompt = 1
192
+
193
+ shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
194
+ image.width // pipe.vae_scale_factor)
195
+ noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
196
+
197
+ # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
198
+ num_inference_steps = self.generate_config.sample_steps
199
+ strength = self.generate_config.denoise_strength
200
+ # Get timesteps
201
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
202
+ t_start = max(num_inference_steps - init_timestep, 0)
203
+ pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
204
+ timesteps = pipe.scheduler.timesteps[t_start:]
205
+ timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
206
+ latents = pipe.scheduler.add_noise(latents, noise, timestep)
207
+
208
+ gen_images = pipe.__call__(
209
+ prompt=caption,
210
+ negative_prompt=self.generate_config.neg,
211
+ latents=latents,
212
+ timesteps=timesteps,
213
+ width=image.width,
214
+ height=image.height,
215
+ num_inference_steps=num_inference_steps,
216
+ num_images_per_prompt=num_images_per_prompt,
217
+ guidance_scale=self.generate_config.guidance_scale,
218
+ # strength=self.generate_config.denoise_strength,
219
+ use_resolution_binning=False,
220
+ output_type="np"
221
+ ).images[0]
222
+ gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
223
+ gen_images = Image.fromarray(gen_images)
224
+ else:
225
+ pipe: StableDiffusionXLImg2ImgPipeline = pipe
226
+
227
+ gen_images = pipe.__call__(
228
+ prompt=caption,
229
+ negative_prompt=self.generate_config.neg,
230
+ image=image,
231
+ num_inference_steps=self.generate_config.sample_steps,
232
+ guidance_scale=self.generate_config.guidance_scale,
233
+ strength=self.generate_config.denoise_strength,
234
+ ).images[0]
235
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
236
+ gen_images.save(output_path)
237
+
238
+ # save caption
239
+ with open(output_caption_path, 'w') as f:
240
+ f.write(caption)
241
+
242
+ if output_inputs_path is not None:
243
+ os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
244
+ image.save(output_inputs_path)
245
+ with open(output_inputs_caption_path, 'w') as f:
246
+ f.write(caption)
247
+
248
+ pbar.update(1)
249
+ batch.cleanup()
250
+
251
+ pbar.close()
252
+ print("Done generating images")
253
+ # cleanup
254
+ del self.sd
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
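[Note] The img2img timestep selection above is easier to follow with concrete numbers. A standalone sketch of the same arithmetic, using the config defaults of sample_steps=20 and denoise_strength=0.5:

num_inference_steps = 20   # generate.sample_steps
strength = 0.5             # generate.denoise_strength

# Same computation as in Img2ImgGenerator.run above.
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 10
t_start = max(num_inference_steps - init_timestep, 0)                          # 10

# Only the last `init_timestep` scheduler steps are denoised, so strength=0.5 keeps
# roughly half of the source image's structure; strength near 1.0 starts from almost
# pure noise, and strength near 0.0 barely changes the input.
print(init_timestep, t_start)  # -> 10 10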
extensions_built_in/advanced_generator/PureLoraGenerator.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ from collections import OrderedDict
3
+
4
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
5
+ from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
6
+ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ import gc
9
+ import torch
10
+ from jobs.process import BaseExtensionProcess
11
+ from toolkit.train_tools import get_torch_dtype
12
+
13
+
14
+ def flush():
15
+ torch.cuda.empty_cache()
16
+ gc.collect()
17
+
18
+
19
+ class PureLoraGenerator(BaseExtensionProcess):
20
+
21
+ def __init__(self, process_id: int, job, config: OrderedDict):
22
+ super().__init__(process_id, job, config)
23
+ self.output_folder = self.get_conf('output_folder', required=True)
24
+ self.device = self.get_conf('device', 'cuda')
25
+ self.device_torch = torch.device(self.device)
26
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
27
+ self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
28
+ self.dtype = self.get_conf('dtype', 'float16')
29
+ self.torch_dtype = get_torch_dtype(self.dtype)
30
+ lorm_config = self.get_conf('lorm', None)
31
+ self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
32
+
33
+ self.device_state_preset = get_train_sd_device_state_preset(
34
+ device=torch.device(self.device),
35
+ )
36
+
37
+ self.progress_bar = None
38
+ self.sd = StableDiffusion(
39
+ device=self.device,
40
+ model_config=self.model_config,
41
+ dtype=self.dtype,
42
+ )
43
+
44
+ def run(self):
45
+ super().run()
46
+ print("Loading model...")
47
+ with torch.no_grad():
48
+ self.sd.load_model()
49
+ self.sd.unet.eval()
50
+ self.sd.unet.to(self.device_torch)
51
+ if isinstance(self.sd.text_encoder, list):
52
+ for te in self.sd.text_encoder:
53
+ te.eval()
54
+ te.to(self.device_torch)
55
+ else:
56
+ self.sd.text_encoder.eval()
57
+ self.sd.to(self.device_torch)
58
+
59
+ print(f"Converting to LoRM UNet")
60
+ # replace the unet with LoRMUnet
61
+ convert_diffusers_unet_to_lorm(
62
+ self.sd.unet,
63
+ config=self.lorm_config,
64
+ )
65
+
66
+ sample_folder = os.path.join(self.output_folder)
67
+ gen_img_config_list = []
68
+
69
+ sample_config = self.generate_config
70
+ start_seed = sample_config.seed
71
+ current_seed = start_seed
72
+ for i in range(len(sample_config.prompts)):
73
+ if sample_config.walk_seed:
74
+ current_seed = start_seed + i
75
+
76
+ filename = f"[time]_[count].{self.generate_config.ext}"
77
+ output_path = os.path.join(sample_folder, filename)
78
+ prompt = sample_config.prompts[i]
79
+ extra_args = {}
80
+ gen_img_config_list.append(GenerateImageConfig(
81
+ prompt=prompt, # it will autoparse the prompt
82
+ width=sample_config.width,
83
+ height=sample_config.height,
84
+ negative_prompt=sample_config.neg,
85
+ seed=current_seed,
86
+ guidance_scale=sample_config.guidance_scale,
87
+ guidance_rescale=sample_config.guidance_rescale,
88
+ num_inference_steps=sample_config.sample_steps,
89
+ network_multiplier=sample_config.network_multiplier,
90
+ output_path=output_path,
91
+ output_ext=sample_config.ext,
92
+ adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
93
+ **extra_args
94
+ ))
95
+
96
+ # send to be generated
97
+ self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
98
+ print("Done generating images")
99
+ # cleanup
100
+ del self.sd
101
+ gc.collect()
102
+ torch.cuda.empty_cache()
extensions_built_in/advanced_generator/ReferenceGenerator.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ import random
3
+ from collections import OrderedDict
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+ from diffusers import T2IAdapter
9
+ from torch.utils.data import DataLoader
10
+ from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
11
+ from tqdm import tqdm
12
+
13
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
14
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
15
+ from toolkit.sampler import get_sampler
16
+ from toolkit.stable_diffusion_model import StableDiffusion
17
+ import gc
18
+ import torch
19
+ from jobs.process import BaseExtensionProcess
20
+ from toolkit.data_loader import get_dataloader_from_datasets
21
+ from toolkit.train_tools import get_torch_dtype
22
+ from controlnet_aux.midas import MidasDetector
23
+ from diffusers.utils import load_image
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ class GenerateConfig:
32
+
33
+ def __init__(self, **kwargs):
34
+ self.prompts: List[str]
35
+ self.sampler = kwargs.get('sampler', 'ddpm')
36
+ self.neg = kwargs.get('neg', '')
37
+ self.seed = kwargs.get('seed', -1)
38
+ self.walk_seed = kwargs.get('walk_seed', False)
39
+ self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
40
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
41
+ self.sample_steps = kwargs.get('sample_steps', 20)
42
+ self.prompt_2 = kwargs.get('prompt_2', None)
43
+ self.neg_2 = kwargs.get('neg_2', None)
44
+ self.prompts = kwargs.get('prompts', None)
45
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
46
+ self.ext = kwargs.get('ext', 'png')
47
+ self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
48
+ if kwargs.get('shuffle', False):
49
+ # shuffle the prompts
50
+ random.shuffle(self.prompts)
51
+
52
+
53
+ class ReferenceGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.device = self.get_conf('device', 'cuda')
59
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
60
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
61
+ self.is_latents_cached = True
62
+ raw_datasets = self.get_conf('datasets', None)
63
+ if raw_datasets is not None and len(raw_datasets) > 0:
64
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
65
+ self.datasets = None
66
+ self.datasets_reg = None
67
+ self.dtype = self.get_conf('dtype', 'float16')
68
+ self.torch_dtype = get_torch_dtype(self.dtype)
69
+ self.params = []
70
+ if raw_datasets is not None and len(raw_datasets) > 0:
71
+ for raw_dataset in raw_datasets:
72
+ dataset = DatasetConfig(**raw_dataset)
73
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
74
+ if not is_caching:
75
+ self.is_latents_cached = False
76
+ if dataset.is_reg:
77
+ if self.datasets_reg is None:
78
+ self.datasets_reg = []
79
+ self.datasets_reg.append(dataset)
80
+ else:
81
+ if self.datasets is None:
82
+ self.datasets = []
83
+ self.datasets.append(dataset)
84
+
85
+ self.progress_bar = None
86
+ self.sd = StableDiffusion(
87
+ device=self.device,
88
+ model_config=self.model_config,
89
+ dtype=self.dtype,
90
+ )
91
+ print(f"Using device {self.device}")
92
+ self.data_loader: DataLoader = None
93
+ self.adapter: T2IAdapter = None
94
+
95
+ def run(self):
96
+ super().run()
97
+ print("Loading model...")
98
+ self.sd.load_model()
99
+ device = torch.device(self.device)
100
+
101
+ if self.generate_config.t2i_adapter_path is not None:
102
+ self.adapter = T2IAdapter.from_pretrained(
103
+ self.generate_config.t2i_adapter_path,
104
+ torch_dtype=self.torch_dtype,
105
+ varient="fp16"
106
+ ).to(device)
107
+
108
+ midas_depth = MidasDetector.from_pretrained(
109
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
110
+ ).to(device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLAdapterPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ adapter=self.adapter,
122
+ ).to(device, dtype=self.torch_dtype)
123
+ else:
124
+ pipe = StableDiffusionAdapterPipeline(
125
+ vae=self.sd.vae,
126
+ unet=self.sd.unet,
127
+ text_encoder=self.sd.text_encoder,
128
+ tokenizer=self.sd.tokenizer,
129
+ scheduler=get_sampler(self.generate_config.sampler),
130
+ safety_checker=None,
131
+ feature_extractor=None,
132
+ requires_safety_checker=False,
133
+ adapter=self.adapter,
134
+ ).to(device, dtype=self.torch_dtype)
135
+ pipe.set_progress_bar_config(disable=True)
136
+
137
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
138
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
139
+
140
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
141
+
142
+ num_batches = len(self.data_loader)
143
+ pbar = tqdm(total=num_batches, desc="Generating images")
144
+ seed = self.generate_config.seed
145
+ # load images from datasets, use tqdm
146
+ for i, batch in enumerate(self.data_loader):
147
+ batch: DataLoaderBatchDTO = batch
148
+
149
+ file_item: FileItemDTO = batch.file_items[0]
150
+ img_path = file_item.path
151
+ img_filename = os.path.basename(img_path)
152
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
153
+ output_path = os.path.join(self.output_folder, img_filename)
154
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
155
+ output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
156
+
157
+ caption = batch.get_caption_list()[0]
158
+
159
+ img: torch.Tensor = batch.tensor.clone()
160
+ # image comes in -1 to 1. convert to a PIL RGB image
161
+ img = (img + 1) / 2
162
+ img = img.clamp(0, 1)
163
+ img = img[0].permute(1, 2, 0).cpu().numpy()
164
+ img = (img * 255).astype(np.uint8)
165
+ image = Image.fromarray(img)
166
+
167
+ width, height = image.size
168
+ min_res = min(width, height)
169
+
170
+ if self.generate_config.walk_seed:
171
+ seed = seed + 1
172
+
173
+ if self.generate_config.seed == -1:
174
+ # random
175
+ seed = random.randint(0, 1000000)
176
+
177
+ torch.manual_seed(seed)
178
+ torch.cuda.manual_seed(seed)
179
+
180
+ # generate depth map
181
+ image = midas_depth(
182
+ image,
183
+ detect_resolution=min_res, # do 512 ?
184
+ image_resolution=min_res
185
+ )
186
+
187
+ # image.save(output_depth_path)
188
+
189
+ gen_images = pipe(
190
+ prompt=caption,
191
+ negative_prompt=self.generate_config.neg,
192
+ image=image,
193
+ num_inference_steps=self.generate_config.sample_steps,
194
+ adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
195
+ guidance_scale=self.generate_config.guidance_scale,
196
+ ).images[0]
197
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
198
+ gen_images.save(output_path)
199
+
200
+ # save caption
201
+ with open(output_caption_path, 'w') as f:
202
+ f.write(caption)
203
+
204
+ pbar.update(1)
205
+ batch.cleanup()
206
+
207
+ pbar.close()
208
+ print("Done generating images")
209
+ # cleanup
210
+ del self.sd
211
+ gc.collect()
212
+ torch.cuda.empty_cache()
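A note on the tensor-to-image step above: `batch.tensor` comes out of the data loader normalized to `-1..1`, so it has to be rescaled before MiDaS or the adapter pipeline can use it. A minimal standalone sketch of that conversion, under the same shape and range assumptions as the loop above:

```python
import numpy as np
import torch
from PIL import Image


def batch_tensor_to_pil(batch_tensor: torch.Tensor) -> Image.Image:
    """Convert the first item of a (B, C, H, W) tensor in [-1, 1] to a PIL RGB image."""
    img = (batch_tensor[0].clamp(-1, 1) + 1) / 2      # -> [0, 1]
    img = img.permute(1, 2, 0).cpu().numpy()          # -> (H, W, C)
    return Image.fromarray((img * 255).astype(np.uint8))
```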
extensions_built_in/advanced_generator/__init__.py ADDED
@@ -0,0 +1,59 @@
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # This extension provides a generation process, not a trainer
6
+ class AdvancedReferenceGeneratorExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "reference_generator"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Reference Generator"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ReferenceGenerator import ReferenceGenerator
19
+ return ReferenceGenerator
20
+
21
+
22
+ # This extension provides a generation process, not a trainer
23
+ class PureLoraGenerator(Extension):
24
+ # uid must be unique, it is how the extension is identified
25
+ uid = "pure_lora_generator"
26
+
27
+ # name is the name of the extension for printing
28
+ name = "Pure LoRA Generator"
29
+
30
+ # This is where your process class is loaded
31
+ # keep your imports in here so they don't slow down the rest of the program
32
+ @classmethod
33
+ def get_process(cls):
34
+ # import your process class here so it is only loaded when needed and return it
35
+ from .PureLoraGenerator import PureLoraGenerator
36
+ return PureLoraGenerator
37
+
38
+
39
+ # This extension provides a generation process, not a trainer
40
+ class Img2ImgGeneratorExtension(Extension):
41
+ # uid must be unique, it is how the extension is identified
42
+ uid = "batch_img2img"
43
+
44
+ # name is the name of the extension for printing
45
+ name = "Img2ImgGeneratorExtension"
46
+
47
+ # This is where your process class is loaded
48
+ # keep your imports in here so they don't slow down the rest of the program
49
+ @classmethod
50
+ def get_process(cls):
51
+ # import your process class here so it is only loaded when needed and return it
52
+ from .Img2ImgGenerator import Img2ImgGenerator
53
+ return Img2ImgGenerator
54
+
55
+
56
+ AI_TOOLKIT_EXTENSIONS = [
57
+ # you can put a list of extensions here
58
+ AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
59
+ ]
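Each class above follows the same contract: subclass `Extension`, set a unique `uid` and a display `name`, and return the process class lazily from `get_process()` so heavy imports only happen when the extension is actually run. A hypothetical extension following that pattern (the module and class names here are illustrative, not part of the toolkit):

```python
from toolkit.extension import Extension


class MyGeneratorExtension(Extension):
    # uid must be unique across all registered extensions
    uid = "my_generator"

    # name is only used for printing
    name = "My Generator"

    @classmethod
    def get_process(cls):
        # lazy import keeps startup fast when this extension is unused
        from .MyGenerator import MyGenerator  # hypothetical process module
        return MyGenerator


AI_TOOLKIT_EXTENSIONS = [MyGeneratorExtension]
```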
extensions_built_in/advanced_generator/config/train.example.yaml ADDED
@@ -0,0 +1,91 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: test_v1
5
+ process:
6
+ - type: 'textual_inversion_trainer'
7
+ training_folder: "out/TI"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "out/.tensorboard"
11
+ embedding:
12
+ trigger: "your_trigger_here"
13
+ tokens: 12
14
+ init_words: "man with short brown hair"
15
+ save_format: "safetensors" # 'safetensors' or 'pt'
16
+ save:
17
+ dtype: float16 # precision to save
18
+ save_every: 100 # save every this many steps
19
+ max_step_saves_to_keep: 5 # only affects step counts
20
+ datasets:
21
+ - folder_path: "/path/to/dataset"
22
+ caption_ext: "txt"
23
+ default_caption: "[trigger]"
24
+ buckets: true
25
+ resolution: 512
26
+ train:
27
+ noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
28
+ steps: 3000
29
+ weight_jitter: 0.0
30
+ lr: 5e-5
31
+ train_unet: false
32
+ gradient_checkpointing: true
33
+ train_text_encoder: false
34
+ optimizer: "adamw"
35
+ # optimizer: "prodigy"
36
+ optimizer_params:
37
+ weight_decay: 1e-2
38
+ lr_scheduler: "constant"
39
+ max_denoising_steps: 1000
40
+ batch_size: 4
41
+ dtype: bf16
42
+ xformers: true
43
+ min_snr_gamma: 5.0
44
+ # skip_first_sample: true
45
+ noise_offset: 0.0 # not needed for this
46
+ model:
47
+ # objective reality v2
48
+ name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
49
+ is_v2: false # for v2 models
50
+ is_xl: false # for SDXL models
51
+ is_v_pred: false # for v-prediction models (most v2 models)
52
+ sample:
53
+ sampler: "ddpm" # must match train.noise_scheduler
54
+ sample_every: 100 # sample every this many steps
55
+ width: 512
56
+ height: 512
57
+ prompts:
58
+ - "photo of [trigger] laughing"
59
+ - "photo of [trigger] smiling"
60
+ - "[trigger] close up"
61
+ - "dark scene [trigger] frozen"
62
+ - "[trigger] nighttime"
63
+ - "a painting of [trigger]"
64
+ - "a drawing of [trigger]"
65
+ - "a cartoon of [trigger]"
66
+ - "[trigger] pixar style"
67
+ - "[trigger] costume"
68
+ neg: ""
69
+ seed: 42
70
+ walk_seed: false
71
+ guidance_scale: 7
72
+ sample_steps: 20
73
+ network_multiplier: 1.0
74
+
75
+ logging:
76
+ log_every: 10 # log every this many steps
77
+ use_wandb: false # not supported yet
78
+ verbose: false
79
+
80
+ # You can put any information you want here, and it will be saved in the model.
81
+ # The below is an example, but you can put your grocery list in it if you want.
82
+ # It is saved in the model so be aware of that. The software will include this
83
+ # plus some other information for you automatically
84
+ meta:
85
+ # [name] gets replaced with the name above
86
+ name: "[name]"
87
+ # version: '1.0'
88
+ # creator:
89
+ # name: Your Name
90
+ # email: [email protected]
91
+ # website: https://your.website
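The example config is plain YAML with a single `process` entry. A quick sketch of inspecting it with PyYAML (an assumption here; in normal use the toolkit's own job runner consumes this file):

```python
import yaml

with open("extensions_built_in/advanced_generator/config/train.example.yaml") as f:
    cfg = yaml.safe_load(f)

process = cfg["config"]["process"][0]
print(process["type"])                        # 'textual_inversion_trainer'
print(process["train"]["steps"])              # 3000
print(process["datasets"][0]["folder_path"])  # '/path/to/dataset'
```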
extensions_built_in/concept_replacer/ConceptReplacer.py ADDED
@@ -0,0 +1,151 @@
1
+ import random
2
+ from collections import OrderedDict
3
+ from torch.utils.data import DataLoader
4
+ from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
5
+ from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
6
+ from toolkit.train_tools import get_torch_dtype, apply_snr_weight
7
+ import gc
8
+ import torch
9
+ from jobs.process import BaseSDTrainProcess
10
+
11
+
12
+ def flush():
13
+ torch.cuda.empty_cache()
14
+ gc.collect()
15
+
16
+
17
+ class ConceptReplacementConfig:
18
+ def __init__(self, **kwargs):
19
+ self.concept: str = kwargs.get('concept', '')
20
+ self.replacement: str = kwargs.get('replacement', '')
21
+
22
+
23
+ class ConceptReplacer(BaseSDTrainProcess):
24
+
25
+ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
26
+ super().__init__(process_id, job, config, **kwargs)
27
+ replacement_list = self.config.get('replacements', [])
28
+ self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list]
29
+
30
+ def before_model_load(self):
31
+ pass
32
+
33
+ def hook_before_train_loop(self):
34
+ self.sd.vae.eval()
35
+ self.sd.vae.to(self.device_torch)
36
+
37
+ # textual inversion
38
+ if self.embedding is not None:
39
+ # set text encoder to train. Not sure if this is necessary but diffusers example did it
40
+ self.sd.text_encoder.train()
41
+
42
+ def hook_train_loop(self, batch):
43
+ with torch.no_grad():
44
+ dtype = get_torch_dtype(self.train_config.dtype)
45
+ noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
46
+ network_weight_list = batch.get_network_weight_list()
47
+
48
+ # have a blank network so we can wrap it in a context and set multipliers without checking every time
49
+ if self.network is not None:
50
+ network = self.network
51
+ else:
52
+ network = BlankNetwork()
53
+
54
+ batch_replacement_list = []
55
+ # get a random replacement for each prompt
56
+ for prompt in conditioned_prompts:
57
+ replacement = random.choice(self.replacement_list)
58
+ batch_replacement_list.append(replacement)
59
+
60
+ # build out prompts
61
+ concept_prompts = []
62
+ replacement_prompts = []
63
+ for idx, replacement in enumerate(batch_replacement_list):
64
+ prompt = conditioned_prompts[idx]
65
+
66
+ # insert shuffled concept at beginning and end of prompt
67
+ shuffled_concept = [x.strip() for x in replacement.concept.split(',')]
68
+ random.shuffle(shuffled_concept)
69
+ shuffled_concept = ', '.join(shuffled_concept)
70
+ concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}")
71
+
72
+ # insert replacement at beginning and end of prompt
73
+ shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')]
74
+ random.shuffle(shuffled_replacement)
75
+ shuffled_replacement = ', '.join(shuffled_replacement)
76
+ replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}")
77
+
78
+ # predict the replacement without network
79
+ conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype)
80
+
81
+ replacement_pred = self.sd.predict_noise(
82
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
83
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
84
+ timestep=timesteps,
85
+ guidance_scale=1.0,
86
+ )
87
+
88
+ del conditional_embeds
89
+ replacement_pred = replacement_pred.detach()
90
+
91
+ self.optimizer.zero_grad()
92
+ flush()
93
+
94
+ # text encoding
95
+ grad_on_text_encoder = False
96
+ if self.train_config.train_text_encoder:
97
+ grad_on_text_encoder = True
98
+
99
+ if self.embedding:
100
+ grad_on_text_encoder = True
101
+
102
+ # set the weights
103
+ network.multiplier = network_weight_list
104
+
105
+ # activate network if it exists
106
+ with network:
107
+ with torch.set_grad_enabled(grad_on_text_encoder):
108
+ # embed the prompts
109
+ conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype)
110
+ if not grad_on_text_encoder:
111
+ # detach the embeddings
112
+ conditional_embeds = conditional_embeds.detach()
113
+ self.optimizer.zero_grad()
114
+ flush()
115
+
116
+ noise_pred = self.sd.predict_noise(
117
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
118
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
119
+ timestep=timesteps,
120
+ guidance_scale=1.0,
121
+ )
122
+
123
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
124
+ loss = loss.mean([1, 2, 3])
125
+
126
+ if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
127
+ # add min_snr_gamma
128
+ loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
129
+
130
+ loss = loss.mean()
131
+
132
+ # back propagate loss to free ram
133
+ loss.backward()
134
+ flush()
135
+
136
+ # apply gradients
137
+ self.optimizer.step()
138
+ self.optimizer.zero_grad()
139
+ self.lr_scheduler.step()
140
+
141
+ if self.embedding is not None:
142
+ # Let's make sure we don't update any embedding weights besides the newly added token
143
+ self.embedding.restore_embeddings()
144
+
145
+ loss_dict = OrderedDict(
146
+ {'loss': loss.item()}
147
+ )
148
+ # reset network multiplier
149
+ network.multiplier = 1.0
150
+
151
+ return loss_dict
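The training step above hinges on a prompt pair: the caption is wrapped with a shuffled concept string (what the trained network sees) and a shuffled replacement string (what the frozen model predicts as the regression target). A small self-contained sketch of just that wrapping step, mirroring the loop above; the helper name and the seed argument are illustrative:

```python
import random


def build_prompt_pair(prompt: str, concept: str, replacement: str, seed: int = 0):
    """Wrap a caption with shuffled, comma-separated concept / replacement tags."""
    rng = random.Random(seed)

    def wrap(tags: str) -> str:
        parts = [t.strip() for t in tags.split(',')]
        rng.shuffle(parts)
        shuffled = ', '.join(parts)
        return f"{shuffled}, {prompt}, {shuffled}"

    return wrap(concept), wrap(replacement)


# One possible shuffle:
# build_prompt_pair("a photo of a park", "dog, golden retriever", "cat")
# -> ("golden retriever, dog, a photo of a park, golden retriever, dog",
#     "cat, a photo of a park, cat")
```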
extensions_built_in/concept_replacer/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
6
+ class ConceptReplacerExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "concept_replacer"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Concept Replacer"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ConceptReplacer import ConceptReplacer
19
+ return ConceptReplacer
20
+
21
+
22
+
23
+ AI_TOOLKIT_EXTENSIONS = [
24
+ # you can put a list of extensions here
25
+ ConceptReplacerExtension,
26
+ ]
extensions_built_in/concept_replacer/config/train.example.yaml ADDED
@@ -0,0 +1,91 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: test_v1
5
+ process:
6
+ - type: 'textual_inversion_trainer'
7
+ training_folder: "out/TI"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "out/.tensorboard"
11
+ embedding:
12
+ trigger: "your_trigger_here"
13
+ tokens: 12
14
+ init_words: "man with short brown hair"
15
+ save_format: "safetensors" # 'safetensors' or 'pt'
16
+ save:
17
+ dtype: float16 # precision to save
18
+ save_every: 100 # save every this many steps
19
+ max_step_saves_to_keep: 5 # only affects step counts
20
+ datasets:
21
+ - folder_path: "/path/to/dataset"
22
+ caption_ext: "txt"
23
+ default_caption: "[trigger]"
24
+ buckets: true
25
+ resolution: 512
26
+ train:
27
+ noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
28
+ steps: 3000
29
+ weight_jitter: 0.0
30
+ lr: 5e-5
31
+ train_unet: false
32
+ gradient_checkpointing: true
33
+ train_text_encoder: false
34
+ optimizer: "adamw"
35
+ # optimizer: "prodigy"
36
+ optimizer_params:
37
+ weight_decay: 1e-2
38
+ lr_scheduler: "constant"
39
+ max_denoising_steps: 1000
40
+ batch_size: 4
41
+ dtype: bf16
42
+ xformers: true
43
+ min_snr_gamma: 5.0
44
+ # skip_first_sample: true
45
+ noise_offset: 0.0 # not needed for this
46
+ model:
47
+ # objective reality v2
48
+ name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
49
+ is_v2: false # for v2 models
50
+ is_xl: false # for SDXL models
51
+ is_v_pred: false # for v-prediction models (most v2 models)
52
+ sample:
53
+ sampler: "ddpm" # must match train.noise_scheduler
54
+ sample_every: 100 # sample every this many steps
55
+ width: 512
56
+ height: 512
57
+ prompts:
58
+ - "photo of [trigger] laughing"
59
+ - "photo of [trigger] smiling"
60
+ - "[trigger] close up"
61
+ - "dark scene [trigger] frozen"
62
+ - "[trigger] nighttime"
63
+ - "a painting of [trigger]"
64
+ - "a drawing of [trigger]"
65
+ - "a cartoon of [trigger]"
66
+ - "[trigger] pixar style"
67
+ - "[trigger] costume"
68
+ neg: ""
69
+ seed: 42
70
+ walk_seed: false
71
+ guidance_scale: 7
72
+ sample_steps: 20
73
+ network_multiplier: 1.0
74
+
75
+ logging:
76
+ log_every: 10 # log every this many steps
77
+ use_wandb: false # not supported yet
78
+ verbose: false
79
+
80
+ # You can put any information you want here, and it will be saved in the model.
81
+ # The below is an example, but you can put your grocery list in it if you want.
82
+ # It is saved in the model so be aware of that. The software will include this
83
+ # plus some other information for you automatically
84
+ meta:
85
+ # [name] gets replaced with the name above
86
+ name: "[name]"
87
+ # version: '1.0'
88
+ # creator:
89
+ # name: Your Name
90
+ # email: [email protected]
91
+ # website: https://your.website
extensions_built_in/dataset_tools/DatasetTools.py ADDED
@@ -0,0 +1,20 @@
1
+ from collections import OrderedDict
2
+ import gc
3
+ import torch
4
+ from jobs.process import BaseExtensionProcess
5
+
6
+
7
+ def flush():
8
+ torch.cuda.empty_cache()
9
+ gc.collect()
10
+
11
+
12
+ class DatasetTools(BaseExtensionProcess):
13
+
14
+ def __init__(self, process_id: int, job, config: OrderedDict):
15
+ super().__init__(process_id, job, config)
16
+
17
+ def run(self):
18
+ super().run()
19
+
20
+ raise NotImplementedError("This extension is not yet implemented")
extensions_built_in/dataset_tools/SuperTagger.py ADDED
@@ -0,0 +1,196 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ from collections import OrderedDict
5
+ import gc
6
+ import traceback
7
+ import torch
8
+ from PIL import Image, ImageOps
9
+ from tqdm import tqdm
10
+
11
+ from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
12
+ from .tools.fuyu_utils import FuyuImageProcessor
13
+ from .tools.image_tools import load_image, ImageProcessor, resize_to_max
14
+ from .tools.llava_utils import LLaVAImageProcessor
15
+ from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
16
+ from jobs.process import BaseExtensionProcess
17
+ from .tools.sync_tools import get_img_paths
18
+
19
+ img_ext = ['.jpg', '.jpeg', '.png', '.webp']
20
+
21
+
22
+ def flush():
23
+ torch.cuda.empty_cache()
24
+ gc.collect()
25
+
26
+
27
+ VERSION = 2
28
+
29
+
30
+ class SuperTagger(BaseExtensionProcess):
31
+
32
+ def __init__(self, process_id: int, job, config: OrderedDict):
33
+ super().__init__(process_id, job, config)
34
+ parent_dir = config.get('parent_dir', None)
35
+ self.dataset_paths: list[str] = config.get('dataset_paths', [])
36
+ self.device = config.get('device', 'cuda')
37
+ self.steps: list[Step] = config.get('steps', [])
38
+ self.caption_method = config.get('caption_method', 'llava:default')
39
+ self.caption_prompt = config.get('caption_prompt', default_long_prompt)
40
+ self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
41
+ self.force_reprocess_img = config.get('force_reprocess_img', False)
42
+ self.caption_replacements = config.get('caption_replacements', default_replacements)
43
+ self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
44
+ self.master_dataset_dict = OrderedDict()
45
+ self.dataset_master_config_file = config.get('dataset_master_config_file', None)
46
+ if parent_dir is not None and len(self.dataset_paths) == 0:
47
+ # find all folders in the parent_dir
48
+ self.dataset_paths = [
49
+ os.path.join(parent_dir, folder)
50
+ for folder in os.listdir(parent_dir)
51
+ if os.path.isdir(os.path.join(parent_dir, folder))
52
+ ]
53
+ else:
54
+ # make sure they exist
55
+ for dataset_path in self.dataset_paths:
56
+ if not os.path.exists(dataset_path):
57
+ raise ValueError(f"Dataset path does not exist: {dataset_path}")
58
+
59
+ print(f"Found {len(self.dataset_paths)} dataset paths")
60
+
61
+ self.image_processor: ImageProcessor = self.get_image_processor()
62
+
63
+ def get_image_processor(self):
64
+ if self.caption_method.startswith('llava'):
65
+ return LLaVAImageProcessor(device=self.device)
66
+ elif self.caption_method.startswith('fuyu'):
67
+ return FuyuImageProcessor(device=self.device)
68
+ else:
69
+ raise ValueError(f"Unknown caption method: {self.caption_method}")
70
+
71
+ def process_image(self, img_path: str):
72
+ root_img_dir = os.path.dirname(os.path.dirname(img_path))
73
+ filename = os.path.basename(img_path)
74
+ filename_no_ext = os.path.splitext(filename)[0]
75
+ train_dir = os.path.join(root_img_dir, TRAIN_DIR)
76
+ train_img_path = os.path.join(train_dir, filename)
77
+ json_path = os.path.join(train_dir, f"{filename_no_ext}.json")
78
+
79
+ # check if json exists, if it does load it as image info
80
+ if os.path.exists(json_path):
81
+ with open(json_path, 'r') as f:
82
+ img_info = ImgInfo(**json.load(f))
83
+ else:
84
+ img_info = ImgInfo()
85
+
86
+ # always send steps first in case other processes need them
87
+ img_info.add_steps(copy.deepcopy(self.steps))
88
+ img_info.set_version(VERSION)
89
+ img_info.set_caption_method(self.caption_method)
90
+
91
+ image: Image = None
92
+ caption_image: Image = None
93
+
94
+ did_update_image = False
95
+
96
+ # trigger reprocess of steps
97
+ if self.force_reprocess_img:
98
+ img_info.trigger_image_reprocess()
99
+
100
+ # set the image as updated if it does not exist on disk
101
+ if not os.path.exists(train_img_path):
102
+ did_update_image = True
103
+ image = load_image(img_path)
104
+ if img_info.force_image_process:
105
+ did_update_image = True
106
+ image = load_image(img_path)
107
+
108
+ # go through the needed steps
109
+ for step in copy.deepcopy(img_info.state.steps_to_complete):
110
+ if step == 'caption':
111
+ # load image
112
+ if image is None:
113
+ image = load_image(img_path)
114
+ if caption_image is None:
115
+ caption_image = resize_to_max(image, 1024, 1024)
116
+
117
+ if not self.image_processor.is_loaded:
118
+ print('Loading Model. Takes a while, especially the first time')
119
+ self.image_processor.load_model()
120
+
121
+ img_info.caption = self.image_processor.generate_caption(
122
+ image=caption_image,
123
+ prompt=self.caption_prompt,
124
+ replacements=self.caption_replacements
125
+ )
126
+ img_info.mark_step_complete(step)
127
+ elif step == 'caption_short':
128
+ # load image
129
+ if image is None:
130
+ image = load_image(img_path)
131
+
132
+ if caption_image is None:
133
+ caption_image = resize_to_max(image, 1024, 1024)
134
+
135
+ if not self.image_processor.is_loaded:
136
+ print('Loading Model. Takes a while, especially the first time')
137
+ self.image_processor.load_model()
138
+ img_info.caption_short = self.image_processor.generate_caption(
139
+ image=caption_image,
140
+ prompt=self.caption_short_prompt,
141
+ replacements=self.caption_short_replacements
142
+ )
143
+ img_info.mark_step_complete(step)
144
+ elif step == 'contrast_stretch':
145
+ # load image
146
+ if image is None:
147
+ image = load_image(img_path)
148
+ image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
149
+ did_update_image = True
150
+ img_info.mark_step_complete(step)
151
+ else:
152
+ raise ValueError(f"Unknown step: {step}")
153
+
154
+ os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
155
+ if did_update_image:
156
+ image.save(train_img_path)
157
+
158
+ if img_info.is_dirty:
159
+ with open(json_path, 'w') as f:
160
+ json.dump(img_info.to_dict(), f, indent=4)
161
+
162
+ if self.dataset_master_config_file:
163
+ # add to master dict
164
+ self.master_dataset_dict[train_img_path] = img_info.to_dict()
165
+
166
+ def run(self):
167
+ super().run()
168
+ imgs_to_process = []
169
+ # find all images
170
+ for dataset_path in self.dataset_paths:
171
+ raw_dir = os.path.join(dataset_path, RAW_DIR)
172
+ raw_image_paths = get_img_paths(raw_dir)
173
+ for raw_image_path in raw_image_paths:
174
+ imgs_to_process.append(raw_image_path)
175
+
176
+ if len(imgs_to_process) == 0:
177
+ print(f"No images to process")
178
+ else:
179
+ print(f"Found {len(imgs_to_process)} to process")
180
+
181
+ for img_path in tqdm(imgs_to_process, desc="Processing images"):
182
+ try:
183
+ self.process_image(img_path)
184
+ except Exception:
185
+ # print full stack trace
186
+ print(traceback.format_exc())
187
+ continue
188
+ # self.process_image(img_path)
189
+
190
+ if self.dataset_master_config_file is not None:
191
+ # save it as json
192
+ with open(self.dataset_master_config_file, 'w') as f:
193
+ json.dump(self.master_dataset_dict, f, indent=4)
194
+
195
+ del self.image_processor
196
+ flush()
extensions_built_in/dataset_tools/SyncFromCollection.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+ import shutil
3
+ from collections import OrderedDict
4
+ import gc
5
+ from typing import List
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
11
+ from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
12
+ get_img_paths
13
+ from jobs.process import BaseExtensionProcess
14
+
15
+
16
+ def flush():
17
+ torch.cuda.empty_cache()
18
+ gc.collect()
19
+
20
+
21
+ class SyncFromCollection(BaseExtensionProcess):
22
+
23
+ def __init__(self, process_id: int, job, config: OrderedDict):
24
+ super().__init__(process_id, job, config)
25
+
26
+ self.min_width = config.get('min_width', 1024)
27
+ self.min_height = config.get('min_height', 1024)
28
+
29
+ # add our min_width and min_height to each dataset config if they don't exist
30
+ for dataset_config in config.get('dataset_sync', []):
31
+ if 'min_width' not in dataset_config:
32
+ dataset_config['min_width'] = self.min_width
33
+ if 'min_height' not in dataset_config:
34
+ dataset_config['min_height'] = self.min_height
35
+
36
+ self.dataset_configs: List[DatasetSyncCollectionConfig] = [
37
+ DatasetSyncCollectionConfig(**dataset_config)
38
+ for dataset_config in config.get('dataset_sync', [])
39
+ ]
40
+ print(f"Found {len(self.dataset_configs)} dataset configs")
41
+
42
+ def move_new_images(self, root_dir: str):
43
+ raw_dir = os.path.join(root_dir, RAW_DIR)
44
+ new_dir = os.path.join(root_dir, NEW_DIR)
45
+ new_images = get_img_paths(new_dir)
46
+
47
+ for img_path in new_images:
48
+ # move to raw
49
+ new_path = os.path.join(raw_dir, os.path.basename(img_path))
50
+ shutil.move(img_path, new_path)
51
+
52
+ # remove new dir
53
+ shutil.rmtree(new_dir)
54
+
55
+ def sync_dataset(self, config: DatasetSyncCollectionConfig):
56
+ if config.host == 'unsplash':
57
+ get_images = get_unsplash_images
58
+ elif config.host == 'pexels':
59
+ get_images = get_pexels_images
60
+ else:
61
+ raise ValueError(f"Unknown host: {config.host}")
62
+
63
+ results = {
64
+ 'num_downloaded': 0,
65
+ 'num_skipped': 0,
66
+ 'bad': 0,
67
+ 'total': 0,
68
+ }
69
+
70
+ photos = get_images(config)
71
+ raw_dir = os.path.join(config.directory, RAW_DIR)
72
+ new_dir = os.path.join(config.directory, NEW_DIR)
73
+ raw_images = get_local_image_file_names(raw_dir)
74
+ new_images = get_local_image_file_names(new_dir)
75
+
76
+ for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
77
+ try:
78
+ if photo.filename not in raw_images and photo.filename not in new_images:
79
+ download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
80
+ results['num_downloaded'] += 1
81
+ else:
82
+ results['num_skipped'] += 1
83
+ except Exception as e:
84
+ print(f" - BAD({photo.id}): {e}")
85
+ results['bad'] += 1
86
+ continue
87
+ results['total'] += 1
88
+
89
+ return results
90
+
91
+ def print_results(self, results):
92
+ print(
93
+ f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")
94
+
95
+ def run(self):
96
+ super().run()
97
+ print(f"Syncing {len(self.dataset_configs)} datasets")
98
+ all_results = None
99
+ failed_datasets = []
100
+ for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
101
+ try:
102
+ results = self.sync_dataset(dataset_config)
103
+ if all_results is None:
104
+ all_results = {**results}
105
+ else:
106
+ for key, value in results.items():
107
+ all_results[key] += value
108
+
109
+ self.print_results(results)
110
+ except Exception as e:
111
+ print(f" - FAILED: {e}")
112
+ if 'response' in e.__dict__:
113
+ error = f"{e.response.status_code}: {e.response.text}"
114
+ print(f" - {error}")
115
+ failed_datasets.append({'dataset': dataset_config, 'error': error})
116
+ else:
117
+ failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
118
+ continue
119
+
120
+ print("Moving new images to raw")
121
+ for dataset_config in self.dataset_configs:
122
+ self.move_new_images(dataset_config.directory)
123
+
124
+ print("Done syncing datasets")
125
+ self.print_results(all_results)
126
+
127
+ if len(failed_datasets) > 0:
128
+ print(f"Failed to sync {len(failed_datasets)} datasets")
129
+ for failed in failed_datasets:
130
+ print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
131
+ print(f" - ERR: {failed['error']}")
extensions_built_in/dataset_tools/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ from toolkit.extension import Extension
2
+
3
+
4
+ class DatasetToolsExtension(Extension):
5
+ uid = "dataset_tools"
6
+
7
+ # name is the name of the extension for printing
8
+ name = "Dataset Tools"
9
+
10
+ # This is where your process class is loaded
11
+ # keep your imports in here so they don't slow down the rest of the program
12
+ @classmethod
13
+ def get_process(cls):
14
+ # import your process class here so it is only loaded when needed and return it
15
+ from .DatasetTools import DatasetTools
16
+ return DatasetTools
17
+
18
+
19
+ class SyncFromCollectionExtension(Extension):
20
+ uid = "sync_from_collection"
21
+ name = "Sync from Collection"
22
+
23
+ @classmethod
24
+ def get_process(cls):
25
+ # import your process class here so it is only loaded when needed and return it
26
+ from .SyncFromCollection import SyncFromCollection
27
+ return SyncFromCollection
28
+
29
+
30
+ class SuperTaggerExtension(Extension):
31
+ uid = "super_tagger"
32
+ name = "Super Tagger"
33
+
34
+ @classmethod
35
+ def get_process(cls):
36
+ # import your process class here so it is only loaded when needed and return it
37
+ from .SuperTagger import SuperTagger
38
+ return SuperTagger
39
+
40
+
41
+ AI_TOOLKIT_EXTENSIONS = [
42
+ SyncFromCollectionExtension, DatasetToolsExtension, SuperTaggerExtension
43
+ ]
extensions_built_in/dataset_tools/tools/caption.py ADDED
@@ -0,0 +1,53 @@
1
+
2
+ caption_manipulation_steps = ['caption', 'caption_short']
3
+
4
+ default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description of items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
5
+ default_short_prompt = 'caption this image in less than ten words'
6
+
7
+ default_replacements = [
8
+ ("the image features", ""),
9
+ ("the image shows", ""),
10
+ ("the image depicts", ""),
11
+ ("the image is", ""),
12
+ ("in this image", ""),
13
+ ("in the image", ""),
14
+ ]
15
+
16
+
17
+ def clean_caption(cap, replacements=None):
18
+ if replacements is None:
19
+ replacements = default_replacements
20
+
21
+ # remove any newlines
22
+ cap = cap.replace("\n", ", ")
23
+ cap = cap.replace("\r", ", ")
24
+ cap = cap.replace(".", ",")
25
+ cap = cap.replace("\"", "")
26
+
27
+ # remove unicode characters
28
+ cap = cap.encode('ascii', 'ignore').decode('ascii')
29
+
30
+ # make lowercase
31
+ cap = cap.lower()
32
+ # remove any extra spaces
33
+ cap = " ".join(cap.split())
34
+
35
+ for replacement in replacements:
36
+ if replacement[0].startswith('*'):
37
+ # we are removing all text if it starts with this and the rest matches
38
+ search_text = replacement[0][1:]
39
+ if cap.startswith(search_text):
40
+ cap = ""
41
+ else:
42
+ cap = cap.replace(replacement[0].lower(), replacement[1].lower())
43
+
44
+ cap_list = cap.split(",")
45
+ # trim whitespace
46
+ cap_list = [c.strip() for c in cap_list]
47
+ # remove empty strings
48
+ cap_list = [c for c in cap_list if c != ""]
49
+ # remove duplicates
50
+ cap_list = list(dict.fromkeys(cap_list))
51
+ # join back together
52
+ cap = ", ".join(cap_list)
53
+ return cap
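A quick usage sketch of `clean_caption`, showing what the default replacements, lower-casing, and de-duplication do to a typical model output (the import path assumes the repository root is importable):

```python
from extensions_built_in.dataset_tools.tools.caption import clean_caption

raw = "The image shows a man with short brown hair. He is smiling.\nHe is smiling."
print(clean_caption(raw))
# -> 'a man with short brown hair, he is smiling'
```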
extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py ADDED
@@ -0,0 +1,187 @@
1
+ import json
2
+ from typing import Literal, Type, TYPE_CHECKING
3
+
4
+ Host: Type = Literal['unsplash', 'pexels']
5
+
6
+ RAW_DIR = "raw"
7
+ NEW_DIR = "_tmp"
8
+ TRAIN_DIR = "train"
9
+ DEPTH_DIR = "depth"
10
+
11
+ from .image_tools import Step, img_manipulation_steps
12
+ from .caption import caption_manipulation_steps
13
+
14
+
15
+ class DatasetSyncCollectionConfig:
16
+ def __init__(self, **kwargs):
17
+ self.host: Host = kwargs.get('host', None)
18
+ self.collection_id: str = kwargs.get('collection_id', None)
19
+ self.directory: str = kwargs.get('directory', None)
20
+ self.api_key: str = kwargs.get('api_key', None)
21
+ self.min_width: int = kwargs.get('min_width', 1024)
22
+ self.min_height: int = kwargs.get('min_height', 1024)
23
+
24
+ if self.host is None:
25
+ raise ValueError("host is required")
26
+ if self.collection_id is None:
27
+ raise ValueError("collection_id is required")
28
+ if self.directory is None:
29
+ raise ValueError("directory is required")
30
+ if self.api_key is None:
31
+ raise ValueError(f"api_key is required: {self.host}:{self.collection_id}")
32
+
33
+
34
+ class ImageState:
35
+ def __init__(self, **kwargs):
36
+ self.steps_complete: list[Step] = kwargs.get('steps_complete', [])
37
+ self.steps_to_complete: list[Step] = kwargs.get('steps_to_complete', [])
38
+
39
+ def to_dict(self):
40
+ return {
41
+ 'steps_complete': self.steps_complete
42
+ }
43
+
44
+
45
+ class Rect:
46
+ def __init__(self, **kwargs):
47
+ self.x = kwargs.get('x', 0)
48
+ self.y = kwargs.get('y', 0)
49
+ self.width = kwargs.get('width', 0)
50
+ self.height = kwargs.get('height', 0)
51
+
52
+ def to_dict(self):
53
+ return {
54
+ 'x': self.x,
55
+ 'y': self.y,
56
+ 'width': self.width,
57
+ 'height': self.height
58
+ }
59
+
60
+
61
+ class ImgInfo:
62
+ def __init__(self, **kwargs):
63
+ self.version: int = kwargs.get('version', None)
64
+ self.caption: str = kwargs.get('caption', None)
65
+ self.caption_short: str = kwargs.get('caption_short', None)
66
+ self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
67
+ self.state = ImageState(**kwargs.get('state', {}))
68
+ self.caption_method = kwargs.get('caption_method', None)
69
+ self.other_captions = kwargs.get('other_captions', {})
70
+ self._upgrade_state()
71
+ self.force_image_process: bool = False
72
+ self._requested_steps: list[Step] = []
73
+
74
+ self.is_dirty: bool = False
75
+
76
+ def _upgrade_state(self):
77
+ # upgrades older states
78
+ if self.caption is not None and 'caption' not in self.state.steps_complete:
79
+ self.mark_step_complete('caption')
80
+ self.is_dirty = True
81
+ if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
82
+ self.mark_step_complete('caption_short')
83
+ self.is_dirty = True
84
+ if self.caption_method is None and self.caption is not None:
85
+ # added caption method in version 2. Was all llava before that
86
+ self.caption_method = 'llava:default'
87
+ self.is_dirty = True
88
+
89
+ def to_dict(self):
90
+ return {
91
+ 'version': self.version,
92
+ 'caption_method': self.caption_method,
93
+ 'caption': self.caption,
94
+ 'caption_short': self.caption_short,
95
+ 'poi': [poi.to_dict() for poi in self.poi],
96
+ 'state': self.state.to_dict(),
97
+ 'other_captions': self.other_captions
98
+ }
99
+
100
+ def mark_step_complete(self, step: Step):
101
+ if step not in self.state.steps_complete:
102
+ self.state.steps_complete.append(step)
103
+ if step in self.state.steps_to_complete:
104
+ self.state.steps_to_complete.remove(step)
105
+ self.is_dirty = True
106
+
107
+ def add_step(self, step: Step):
108
+ if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
109
+ self.state.steps_to_complete.append(step)
110
+
111
+ def trigger_image_reprocess(self):
112
+ if self._requested_steps is None:
113
+ raise Exception("Must call add_steps before trigger_image_reprocess")
114
+ steps = self._requested_steps
115
+ # remove all image manipulation steps from steps_to_complete
116
+ for step in img_manipulation_steps:
117
+ if step in self.state.steps_to_complete:
118
+ self.state.steps_to_complete.remove(step)
119
+ if step in self.state.steps_complete:
120
+ self.state.steps_complete.remove(step)
121
+ self.force_image_process = True
122
+ self.is_dirty = True
123
+ # we want to keep the order passed in process file
124
+ for step in steps:
125
+ if step in img_manipulation_steps:
126
+ self.add_step(step)
127
+
128
+ def add_steps(self, steps: list[Step]):
129
+ self._requested_steps = [step for step in steps]
130
+ for stage in steps:
131
+ self.add_step(stage)
132
+
133
+ # update steps if we have any img processes not complete, we have to reprocess them all
134
+ # if any steps_to_complete are in img_manipulation_steps
135
+
136
+ is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete])
137
+ order_has_changed = False
138
+
139
+ if not is_manipulating_image:
140
+ # check to see if order has changed. No need to if already redoing it. Will detect if ones are removed
141
+ target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps]
142
+ current_img_manipulation_order = [step for step in self.state.steps_complete if
143
+ step in img_manipulation_steps]
144
+ if target_img_manipulation_order != current_img_manipulation_order:
145
+ order_has_changed = True
146
+
147
+ if is_manipulating_image or order_has_changed:
148
+ self.trigger_image_reprocess()
149
+
150
+ def set_caption_method(self, method: str):
151
+ if self._requested_steps is None:
152
+ raise Exception("Must call add_steps before set_caption_method")
153
+ if self.caption_method != method:
154
+ self.is_dirty = True
155
+ # move previous caption method to other_captions
156
+ if self.caption_method is not None and self.caption is not None or self.caption_short is not None:
157
+ self.other_captions[self.caption_method] = {
158
+ 'caption': self.caption,
159
+ 'caption_short': self.caption_short,
160
+ }
161
+ self.caption_method = method
162
+ self.caption = None
163
+ self.caption_short = None
164
+ # see if we have a caption from the new method
165
+ if method in self.other_captions:
166
+ self.caption = self.other_captions[method].get('caption', None)
167
+ self.caption_short = self.other_captions[method].get('caption_short', None)
168
+ else:
169
+ self.trigger_new_caption()
170
+
171
+ def trigger_new_caption(self):
172
+ self.caption = None
173
+ self.caption_short = None
174
+ self.is_dirty = True
175
+ # check to see if we have any steps in the complete list and move them to the to_complete list
176
+ for step in list(self.state.steps_complete):  # iterate over a copy since items are removed below
177
+ if step in caption_manipulation_steps:
178
+ self.state.steps_complete.remove(step)
179
+ self.state.steps_to_complete.append(step)
180
+
181
+ def to_json(self):
182
+ return json.dumps(self.to_dict())
183
+
184
+ def set_version(self, version: int):
185
+ if self.version != version:
186
+ self.is_dirty = True
187
+ self.version = version
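`ImgInfo` is the JSON sidecar that tracks per-image progress through the tagging pipeline. A short lifecycle sketch mirroring how `SuperTagger.process_image` drives it (the import path assumes the repository root is importable):

```python
from extensions_built_in.dataset_tools.tools.dataset_tools_config_modules import ImgInfo

info = ImgInfo()                                  # fresh image, no sidecar yet
info.add_steps(['caption', 'contrast_stretch'])   # always add steps first
info.set_version(2)
info.set_caption_method('llava:default')

print(info.state.steps_to_complete)   # ['caption', 'contrast_stretch']

info.caption = "a man with short brown hair"
info.mark_step_complete('caption')

print(info.state.steps_complete)      # ['caption']
print(info.is_dirty)                  # True -> the .json sidecar should be rewritten
```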
extensions_built_in/dataset_tools/tools/fuyu_utils.py ADDED
@@ -0,0 +1,66 @@
1
+ from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer
2
+
3
+ from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
4
+ import torch
5
+ from PIL import Image
6
+
7
+
8
+ class FuyuImageProcessor:
9
+ def __init__(self, device='cuda'):
10
+ from transformers import FuyuProcessor, FuyuForCausalLM
11
+ self.device = device
12
+ self.model: FuyuForCausalLM = None
13
+ self.processor: FuyuProcessor = None
14
+ self.dtype = torch.bfloat16
15
+ self.tokenizer: AutoTokenizer
16
+ self.is_loaded = False
17
+
18
+ def load_model(self):
19
+ from transformers import FuyuProcessor, FuyuForCausalLM
20
+ model_path = "adept/fuyu-8b"
21
+ kwargs = {"device_map": self.device}
22
+ kwargs['load_in_4bit'] = True
23
+ kwargs['quantization_config'] = BitsAndBytesConfig(
24
+ load_in_4bit=True,
25
+ bnb_4bit_compute_dtype=self.dtype,
26
+ bnb_4bit_use_double_quant=True,
27
+ bnb_4bit_quant_type='nf4'
28
+ )
29
+ self.processor = FuyuProcessor.from_pretrained(model_path)
30
+ self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
31
+ self.is_loaded = True
32
+
33
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
36
+
37
+ def generate_caption(
38
+ self, image: Image,
39
+ prompt: str = default_long_prompt,
40
+ replacements=default_replacements,
41
+ max_new_tokens=512
42
+ ):
43
+ # prepare inputs for the model
44
+ # text_prompt = f"{prompt}\n"
45
+
46
+ # image = image.convert('RGB')
47
+ model_inputs = self.processor(text=prompt, images=[image])
48
+ model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for k, v in
49
+ model_inputs.items()}
50
+
51
+ generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
52
+ prompt_len = model_inputs["input_ids"].shape[-1]
53
+ output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
54
+ output = clean_caption(output, replacements=replacements)
55
+ return output
56
+
57
+ # inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
58
+ # for k, v in inputs.items():
59
+ # inputs[k] = v.to(self.device)
60
+
61
+ # # autoregressively generate text
62
+ # generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
63
+ # generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True)
64
+ # output = generation_text[0]
65
+ #
66
+ # return clean_caption(output, replacements=replacements)
extensions_built_in/dataset_tools/tools/image_tools.py ADDED
@@ -0,0 +1,49 @@
1
+ from typing import Literal, Type, TYPE_CHECKING, Union
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image, ImageOps
6
+
7
+ Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretch']
8
+
9
+ img_manipulation_steps = ['contrast_stretch']
10
+
11
+ img_ext = ['.jpg', '.jpeg', '.png', '.webp']
12
+
13
+ if TYPE_CHECKING:
14
+ from .llava_utils import LLaVAImageProcessor
15
+ from .fuyu_utils import FuyuImageProcessor
16
+
17
+ ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor']
18
+
19
+
20
+ def pil_to_cv2(image):
21
+ """Convert a PIL image to a cv2 image."""
22
+ return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
23
+
24
+
25
+ def cv2_to_pil(image):
26
+ """Convert a cv2 image to a PIL image."""
27
+ return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
28
+
29
+
30
+ def load_image(img_path: str):
31
+ image = Image.open(img_path).convert('RGB')
32
+ try:
33
+ # transpose with exif data
34
+ image = ImageOps.exif_transpose(image)
35
+ except Exception as e:
36
+ pass
37
+ return image
38
+
39
+
40
+ def resize_to_max(image, max_width=1024, max_height=1024):
41
+ width, height = image.size
42
+ if width <= max_width and height <= max_height:
43
+ return image
44
+
45
+ scale = min(max_width / width, max_height / height)
46
+ width = int(width * scale)
47
+ height = int(height * scale)
48
+
49
+ return image.resize((width, height), Image.LANCZOS)
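`resize_to_max` only downsizes and always preserves aspect ratio; a quick check of that behavior (the import path assumes the repository root is importable):

```python
from PIL import Image

from extensions_built_in.dataset_tools.tools.image_tools import resize_to_max

img = Image.new('RGB', (4000, 3000))
print(resize_to_max(img, 1024, 1024).size)   # (1024, 768)  -- scaled down, aspect kept
print(resize_to_max(img, 4096, 4096).size)   # (4000, 3000) -- already within limits, returned unchanged
```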
extensions_built_in/dataset_tools/tools/llava_utils.py ADDED
@@ -0,0 +1,85 @@
1
+
2
+ from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
3
+
4
+ import torch
5
+ from PIL import Image, ImageOps
6
+
7
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
8
+
9
+ img_ext = ['.jpg', '.jpeg', '.png', '.webp']
10
+
11
+
12
+ class LLaVAImageProcessor:
13
+ def __init__(self, device='cuda'):
14
+ try:
15
+ from llava.model import LlavaLlamaForCausalLM
16
+ except ImportError:
17
+ # print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
18
+ print(
19
+ "You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
20
+ raise
21
+ self.device = device
22
+ self.model: LlavaLlamaForCausalLM = None
23
+ self.tokenizer: AutoTokenizer = None
24
+ self.image_processor: CLIPImageProcessor = None
25
+ self.is_loaded = False
26
+
27
+ def load_model(self):
28
+ from llava.model import LlavaLlamaForCausalLM
29
+
30
+ model_path = "4bit/llava-v1.5-13b-3GB"
31
+ # kwargs = {"device_map": "auto"}
32
+ kwargs = {"device_map": self.device}
33
+ kwargs['load_in_4bit'] = True
34
+ kwargs['quantization_config'] = BitsAndBytesConfig(
35
+ load_in_4bit=True,
36
+ bnb_4bit_compute_dtype=torch.float16,
37
+ bnb_4bit_use_double_quant=True,
38
+ bnb_4bit_quant_type='nf4'
39
+ )
40
+ self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
41
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
42
+ vision_tower = self.model.get_vision_tower()
43
+ if not vision_tower.is_loaded:
44
+ vision_tower.load_model()
45
+ vision_tower.to(device=self.device)
46
+ self.image_processor = vision_tower.image_processor
47
+ self.is_loaded = True
48
+
49
+ def generate_caption(
50
+ self, image: Image,
51
+ prompt: str = default_long_prompt,
52
+ replacements=default_replacements,
53
+ max_new_tokens=512
54
+ ):
55
+ from llava.conversation import conv_templates, SeparatorStyle
56
+ from llava.utils import disable_torch_init
57
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
58
+ from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
59
+ # question = "how many dogs are in the picture?"
60
+ disable_torch_init()
61
+ conv_mode = "llava_v0"
62
+ conv = conv_templates[conv_mode].copy()
63
+ roles = conv.roles
64
+ image_tensor = self.image_processor.preprocess([image], return_tensors='pt')['pixel_values'].half().cuda()
65
+
66
+ inp = f"{roles[0]}: {prompt}"
67
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
68
+ conv.append_message(conv.roles[0], inp)
69
+ conv.append_message(conv.roles[1], None)
70
+ raw_prompt = conv.get_prompt()
71
+ input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
72
+ return_tensors='pt').unsqueeze(0).cuda()
73
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
74
+ keywords = [stop_str]
75
+ stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
76
+ with torch.inference_mode():
77
+ output_ids = self.model.generate(
78
+ input_ids, images=image_tensor, do_sample=True, temperature=0.1,
79
+ max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
80
+ top_p=0.8
81
+ )
82
+ outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
83
+ conv.messages[-1][-1] = outputs
84
+ output = outputs.rsplit('</s>', 1)[0]
85
+ return clean_caption(output, replacements=replacements)
extensions_built_in/dataset_tools/tools/sync_tools.py ADDED
@@ -0,0 +1,279 @@
1
+ import os
2
+ import requests
3
+ import tqdm
4
+ from typing import List, Optional, TYPE_CHECKING
5
+
6
+
7
+ def img_root_path(img_id: str):
8
+ return os.path.dirname(os.path.dirname(img_id))
9
+
10
+
11
+ if TYPE_CHECKING:
12
+ from .dataset_tools_config_modules import DatasetSyncCollectionConfig
13
+
14
+ img_exts = ['.jpg', '.jpeg', '.webp', '.png']
15
+
16
+ class Photo:
17
+ def __init__(
18
+ self,
19
+ id,
20
+ host,
21
+ width,
22
+ height,
23
+ url,
24
+ filename
25
+ ):
26
+ self.id = str(id)
27
+ self.host = host
28
+ self.width = width
29
+ self.height = height
30
+ self.url = url
31
+ self.filename = filename
32
+
33
+
34
+ def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int):
35
+ if img_width > img_height:
36
+ scale = min_height / img_height
37
+ else:
38
+ scale = min_width / img_width
39
+
40
+ new_width = int(img_width * scale)
41
+ new_height = int(img_height * scale)
42
+
43
+ return new_width, new_height
44
+
45
+
46
+ def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
47
+ all_images = []
48
+ next_page = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos"
49
+
50
+ while True:
51
+ response = requests.get(next_page, headers={
52
+ "Authorization": f"{config.api_key}"
53
+ })
54
+ response.raise_for_status()
55
+ data = response.json()
56
+ all_images.extend(data['media'])
57
+ if 'next_page' in data and data['next_page']:
58
+ next_page = data['next_page']
59
+ else:
60
+ break
61
+
62
+ photos = []
63
+ for image in all_images:
64
+ new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
65
+ url = f"{image['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}"
66
+ filename = os.path.basename(image['src']['original'])
67
+
68
+ photos.append(Photo(
69
+ id=image['id'],
70
+ host="pexels",
71
+ width=image['width'],
72
+ height=image['height'],
73
+ url=url,
74
+ filename=filename
75
+ ))
76
+
77
+ return photos
78
+
79
+
80
+ def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
81
+ headers = {
82
+ # "Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"
83
+ "Authorization": f"Client-ID {config.api_key}"
84
+ }
85
+ # headers['Authorization'] = f"Bearer {token}"
86
+
87
+ url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30"
88
+ response = requests.get(url, headers=headers)
89
+ response.raise_for_status()
90
+ res_headers = response.headers
91
+ # parse the link header to get the next page
92
+ # 'Link': '<https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=82>; rel="last", <https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=2>; rel="next"'
93
+ has_next_page = False
94
+ if 'Link' in res_headers:
95
+ has_next_page = True
96
+ link_header = res_headers['Link']
97
+ link_header = link_header.split(',')
98
+ link_header = [link.strip() for link in link_header]
99
+ link_header = [link.split(';') for link in link_header]
100
+ link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header]
101
+ link_header = {link[1]: link[0] for link in link_header}
102
+
103
+ # get page number from last url
104
+ last_page = link_header['rel="last']
105
+ last_page = last_page.split('?')[1]
106
+ last_page = last_page.split('&')
107
+ last_page = [param.split('=') for param in last_page]
108
+ last_page = {param[0]: param[1] for param in last_page}
109
+ last_page = int(last_page['page'])
110
+
111
+ all_images = response.json()
112
+
113
+ if has_next_page:
114
+ # assume we start on page 1, so we don't need to get it again
115
+ for page in tqdm.tqdm(range(2, last_page + 1)):
116
+ url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30"
117
+ response = requests.get(url, headers=headers)
118
+ response.raise_for_status()
119
+ all_images.extend(response.json())
120
+
121
+ photos = []
122
+ for image in all_images:
123
+ new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
124
+ url = f"{image['urls']['raw']}&w={new_width}"
125
+ filename = f"{image['id']}.jpg"
126
+
127
+ photos.append(Photo(
128
+ id=image['id'],
129
+ host="unsplash",
130
+ width=image['width'],
131
+ height=image['height'],
132
+ url=url,
133
+ filename=filename
134
+ ))
135
+
136
+ return photos
137
+
138
+
139
+ def get_img_paths(dir_path: str):
140
+ os.makedirs(dir_path, exist_ok=True)
141
+ local_files = os.listdir(dir_path)
142
+ # remove non image files
143
+ local_files = [file for file in local_files if os.path.splitext(file)[1].lower() in img_exts]
144
+ # make full path
145
+ local_files = [os.path.join(dir_path, file) for file in local_files]
146
+ return local_files
147
+
148
+
149
+ def get_local_image_ids(dir_path: str):
150
+ os.makedirs(dir_path, exist_ok=True)
151
+ local_files = get_img_paths(dir_path)
152
+ # assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
153
+ return set([os.path.basename(file).split('.')[0] for file in local_files])
154
+
155
+
156
+ def get_local_image_file_names(dir_path: str):
157
+ os.makedirs(dir_path, exist_ok=True)
158
+ local_files = get_img_paths(dir_path)
159
+ # assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
160
+ return set([os.path.basename(file) for file in local_files])
161
+
162
+
163
+ def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024):
164
+ img_width = photo.width
165
+ img_height = photo.height
166
+
167
+ if img_width < min_width or img_height < min_height:
168
+ raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}")
169
+
170
+ img_response = requests.get(photo.url)
171
+ img_response.raise_for_status()
172
+ os.makedirs(dir_path, exist_ok=True)
173
+
174
+ filename = os.path.join(dir_path, photo.filename)
175
+ with open(filename, 'wb') as file:
176
+ file.write(img_response.content)
177
+
178
+
179
+ def update_caption(img_path: str):
180
+ # if the caption is a txt file, convert it to a json file
181
+ filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
182
+ # see if it exists
183
+ if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json")):
184
+ # todo add poi and what not
185
+ return # we have a json file
186
+ caption = ""
187
+ # see if txt file exists
188
+ if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")):
189
+ # read it
190
+ with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"), 'r') as file:
191
+ caption = file.read()
192
+ # write json file
193
+ with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json"), 'w') as file:
194
+ file.write(f'{{"caption": "{caption}"}}')
195
+
196
+ # delete txt file
197
+ os.remove(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"))
198
+
199
+
200
+ # def equalize_img(img_path: str):
201
+ # input_path = img_path
202
+ # output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path))
203
+ # os.makedirs(os.path.dirname(output_path), exist_ok=True)
204
+ # process_img(
205
+ # img_path=input_path,
206
+ # output_path=output_path,
207
+ # equalize=True,
208
+ # max_size=2056,
209
+ # white_balance=False,
210
+ # gamma_correction=False,
211
+ # strength=0.6,
212
+ # )
213
+
214
+
215
+ # def annotate_depth(img_path: str):
216
+ # # make fake args
217
+ # args = argparse.Namespace()
218
+ # args.annotator = "midas"
219
+ # args.res = 1024
220
+ #
221
+ # img = cv2.imread(img_path)
222
+ # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
223
+ #
224
+ # output = annotate(img, args)
225
+ #
226
+ # output = output.astype('uint8')
227
+ # output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
228
+ #
229
+ # os.makedirs(os.path.dirname(img_path), exist_ok=True)
230
+ # output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path))
231
+ #
232
+ # cv2.imwrite(output_path, output)
233
+
234
+
235
+ # def invert_depth(img_path: str):
236
+ # img = cv2.imread(img_path)
237
+ # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
238
+ # # invert the colors
239
+ # img = cv2.bitwise_not(img)
240
+ #
241
+ # os.makedirs(os.path.dirname(img_path), exist_ok=True)
242
+ # output_path = os.path.join(img_root_path(img_path), INVERTED_DEPTH_DIR, os.path.basename(img_path))
243
+ # cv2.imwrite(output_path, img)
244
+
245
+
246
+ #
247
+ # # update our list of raw images
248
+ # raw_images = get_img_paths(raw_dir)
249
+ #
250
+ # # update raw captions
251
+ # for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"):
252
+ # update_caption(image_id)
253
+ #
254
+ # # equalize images
255
+ # for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"):
256
+ # if img_path not in eq_images:
257
+ # equalize_img(img_path)
258
+ #
259
+ # # update our list of eq images
260
+ # eq_images = get_img_paths(eq_dir)
261
+ # # update eq captions
262
+ # for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"):
263
+ # update_caption(image_id)
264
+ #
265
+ # # annotate depth
266
+ # depth_dir = os.path.join(root_dir, DEPTH_DIR)
267
+ # depth_images = get_img_paths(depth_dir)
268
+ # for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"):
269
+ # if img_path not in depth_images:
270
+ # annotate_depth(img_path)
271
+ #
272
+ # depth_images = get_img_paths(depth_dir)
273
+ #
274
+ # # invert depth
275
+ # inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR)
276
+ # inv_depth_images = get_img_paths(inv_depth_dir)
277
+ # for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"):
278
+ # if img_path not in inv_depth_images:
279
+ # invert_depth(img_path)
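
The Link-header pagination above (lines 94-109 of sync_tools.py) works, but only because str.strip('"') leaves the dictionary key as rel="last, which is then looked up verbatim. A minimal standalone sketch of the same logic, assuming the standard RFC 5988 style Link header that Unsplash returns, could look like this:

from urllib.parse import urlparse, parse_qs

def parse_last_page(link_header: str) -> int:
    # build a rel -> url map from the comma-separated Link header
    links = {}
    for part in link_header.split(','):
        url_part, rel_part = [p.strip() for p in part.split(';', 1)]
        links[rel_part.split('=', 1)[1].strip('"')] = url_part.strip('<>')
    # pull the page number out of the "last" url's query string
    query = parse_qs(urlparse(links['last']).query)
    return int(query['page'][0])

In the same file, update_caption builds its JSON by string formatting; writing it with json.dumps({"caption": caption}) instead would keep captions containing quotes or newlines valid JSON.
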
extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py ADDED
@@ -0,0 +1,235 @@
1
+ import copy
2
+ import random
3
+ from collections import OrderedDict
4
+ import os
5
+ from contextlib import nullcontext
6
+ from typing import Optional, Union, List
7
+ from torch.utils.data import ConcatDataset, DataLoader
8
+
9
+ from toolkit.config_modules import ReferenceDatasetConfig
10
+ from toolkit.data_loader import PairedImageDataset
11
+ from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
12
+ from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
13
+ from toolkit.train_tools import get_torch_dtype, apply_snr_weight
14
+ import gc
15
+ from toolkit import train_tools
16
+ import torch
17
+ from jobs.process import BaseSDTrainProcess
18
+ import random
19
+ from toolkit.basic import value_map
20
+
21
+
22
+ def flush():
23
+ torch.cuda.empty_cache()
24
+ gc.collect()
25
+
26
+
27
+ class ReferenceSliderConfig:
28
+ def __init__(self, **kwargs):
29
+ self.additional_losses: List[str] = kwargs.get('additional_losses', [])
30
+ self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
31
+ self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])]
32
+
33
+
34
+ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
35
+ sd: StableDiffusion
36
+ data_loader: DataLoader = None
37
+
38
+ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
39
+ super().__init__(process_id, job, config, **kwargs)
40
+ self.prompt_txt_list = None
41
+ self.step_num = 0
42
+ self.start_step = 0
43
+ self.device = self.get_conf('device', self.job.device)
44
+ self.device_torch = torch.device(self.device)
45
+ self.slider_config = ReferenceSliderConfig(**self.get_conf('slider', {}))
46
+
47
+ def load_datasets(self):
48
+ if self.data_loader is None:
49
+ print(f"Loading datasets")
50
+ datasets = []
51
+ for dataset in self.slider_config.datasets:
52
+ print(f" - Dataset: {dataset.pair_folder}")
53
+ config = {
54
+ 'path': dataset.pair_folder,
55
+ 'size': dataset.size,
56
+ 'default_prompt': dataset.target_class,
57
+ 'network_weight': dataset.network_weight,
58
+ 'pos_weight': dataset.pos_weight,
59
+ 'neg_weight': dataset.neg_weight,
60
+ 'pos_folder': dataset.pos_folder,
61
+ 'neg_folder': dataset.neg_folder,
62
+ }
63
+ image_dataset = PairedImageDataset(config)
64
+ datasets.append(image_dataset)
65
+
66
+ concatenated_dataset = ConcatDataset(datasets)
67
+ self.data_loader = DataLoader(
68
+ concatenated_dataset,
69
+ batch_size=self.train_config.batch_size,
70
+ shuffle=True,
71
+ num_workers=2
72
+ )
73
+
74
+ def before_model_load(self):
75
+ pass
76
+
77
+ def hook_before_train_loop(self):
78
+ self.sd.vae.eval()
79
+ self.sd.vae.to(self.device_torch)
80
+ self.load_datasets()
81
+
82
+ pass
83
+
84
+ def hook_train_loop(self, batch):
85
+ with torch.no_grad():
86
+ imgs, prompts, network_weights = batch
87
+ network_pos_weight, network_neg_weight = network_weights
88
+
89
+ if isinstance(network_pos_weight, torch.Tensor):
90
+ network_pos_weight = network_pos_weight.item()
91
+ if isinstance(network_neg_weight, torch.Tensor):
92
+ network_neg_weight = network_neg_weight.item()
93
+
94
+ # get a random float between -weight_jitter and weight_jitter
95
+ loss_jitter_multiplier = 1.0
96
+ weight_jitter = self.slider_config.weight_jitter
97
+ if weight_jitter > 0.0:
98
+ jitter_list = random.uniform(-weight_jitter, weight_jitter)
99
+ orig_network_pos_weight = network_pos_weight
100
+ network_pos_weight += jitter_list
101
+ network_neg_weight += (jitter_list * -1.0)
102
+ # down-weight the loss based on how far the jitter moved us from network_pos_weight
103
+ # e.g. a jitter of abs(3.0) with weight_jitter 5.0 is a 60% jitter
104
+ # so the loss_jitter_multiplier needs to be 0.4
105
+ loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0)
106
+
107
+
108
+ # if items in network_weight list are tensors, convert them to floats
109
+
110
+ dtype = get_torch_dtype(self.train_config.dtype)
111
+ imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
112
+ # split batched images in half so left is negative and right is positive
113
+ negative_images, positive_images = torch.chunk(imgs, 2, dim=3)
114
+
115
+ positive_latents = self.sd.encode_images(positive_images)
116
+ negative_latents = self.sd.encode_images(negative_images)
117
+
118
+ height = positive_images.shape[2]
119
+ width = positive_images.shape[3]
120
+ batch_size = positive_images.shape[0]
121
+
122
+ if self.train_config.gradient_checkpointing:
123
+ # may get disabled elsewhere
124
+ self.sd.unet.enable_gradient_checkpointing()
125
+
126
+ noise_scheduler = self.sd.noise_scheduler
127
+ optimizer = self.optimizer
128
+ lr_scheduler = self.lr_scheduler
129
+
130
+ self.sd.noise_scheduler.set_timesteps(
131
+ self.train_config.max_denoising_steps, device=self.device_torch
132
+ )
133
+
134
+ timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
135
+ timesteps = timesteps.long()
136
+
137
+ # get noise
138
+ noise_positive = self.sd.get_latent_noise(
139
+ pixel_height=height,
140
+ pixel_width=width,
141
+ batch_size=batch_size,
142
+ noise_offset=self.train_config.noise_offset,
143
+ ).to(self.device_torch, dtype=dtype)
144
+
145
+ noise_negative = noise_positive.clone()
146
+
147
+ # Add noise to the latents according to the noise magnitude at each timestep
148
+ # (this is the forward diffusion process)
149
+ noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps)
150
+ noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps)
151
+
152
+ noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
153
+ noise = torch.cat([noise_positive, noise_negative], dim=0)
154
+ timesteps = torch.cat([timesteps, timesteps], dim=0)
155
+ network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]
156
+
157
+ self.optimizer.zero_grad()
158
+ noisy_latents.requires_grad = False
159
+
160
+ # if training text encoder enable grads, else do context of no grad
161
+ with torch.set_grad_enabled(self.train_config.train_text_encoder):
162
+ # fix issue with them being tuples sometimes
163
+ prompt_list = []
164
+ for prompt in prompts:
165
+ if isinstance(prompt, tuple):
166
+ prompt = prompt[0]
167
+ prompt_list.append(prompt)
168
+ conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
169
+ conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
170
+
171
+ # if self.model_config.is_xl:
172
+ # # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
173
+ # network_multiplier_list = network_multiplier
174
+ # noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
175
+ # noise_list = torch.chunk(noise, 2, dim=0)
176
+ # timesteps_list = torch.chunk(timesteps, 2, dim=0)
177
+ # conditional_embeds_list = split_prompt_embeds(conditional_embeds)
178
+ # else:
179
+ network_multiplier_list = [network_multiplier]
180
+ noisy_latent_list = [noisy_latents]
181
+ noise_list = [noise]
182
+ timesteps_list = [timesteps]
183
+ conditional_embeds_list = [conditional_embeds]
184
+
185
+ losses = []
186
+ # allow to chunk it out to save vram
187
+ for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip(
188
+ network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list
189
+ ):
190
+ with self.network:
191
+ assert self.network.is_active
192
+
193
+ self.network.multiplier = network_multiplier
194
+
195
+ noise_pred = self.sd.predict_noise(
196
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
197
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
198
+ timestep=timesteps,
199
+ )
200
+ noise = noise.to(self.device_torch, dtype=dtype)
201
+
202
+ if self.sd.prediction_type == 'v_prediction':
203
+ # v-parameterization training
204
+ target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
205
+ else:
206
+ target = noise
207
+
208
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
209
+ loss = loss.mean([1, 2, 3])
210
+
211
+ if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
212
+ # add min_snr_gamma
213
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
214
+
215
+ loss = loss.mean() * loss_jitter_multiplier
216
+
217
+ loss_float = loss.item()
218
+ losses.append(loss_float)
219
+
220
+ # back propagate loss to free ram
221
+ loss.backward()
222
+
223
+ # apply gradients
224
+ optimizer.step()
225
+ lr_scheduler.step()
226
+
227
+ # reset network
228
+ self.network.multiplier = 1.0
229
+
230
+ loss_dict = OrderedDict(
231
+ {'loss': sum(losses) / len(losses) if len(losses) > 0 else 0.0}
232
+ )
233
+
234
+ return loss_dict
235
+ # end hook_train_loop
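
A note on the weight jitter used in hook_train_loop above: despite its name, jitter_list is a single float drawn from uniform(-weight_jitter, weight_jitter), and the pair's contribution to the loss is scaled down the further the jitter moves the weights. A tiny standalone illustration, assuming toolkit.basic.value_map is a plain linear remap (which is how it is used here):

def value_map(x, in_min, in_max, out_min, out_max):
    # linear remap of x from [in_min, in_max] to [out_min, out_max]
    return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min

weight_jitter = 5.0
jitter = 3.0  # a sample from uniform(-weight_jitter, weight_jitter)
loss_jitter_multiplier = value_map(abs(jitter), 0.0, weight_jitter, 1.0, 0.0)
print(loss_jitter_multiplier)  # 0.4 -> a 60% jitter scales the loss to 40%
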
extensions_built_in/image_reference_slider_trainer/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # We make a subclass of Extension
6
+ class ImageReferenceSliderTrainer(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "image_reference_slider_trainer"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Image Reference Slider Trainer"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ImageReferenceSliderTrainerProcess import ImageReferenceSliderTrainerProcess
19
+ return ImageReferenceSliderTrainerProcess
20
+
21
+
22
+ AI_TOOLKIT_EXTENSIONS = [
23
+ # you can put a list of extensions here
24
+ ImageReferenceSliderTrainer
25
+ ]
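
The __init__.py above keeps the heavy imports inside get_process so that merely discovering extensions stays cheap. A hypothetical sketch of how a loader could consume AI_TOOLKIT_EXTENSIONS (the real discovery code lives in toolkit/extension.py and is not part of this diff, so names here are illustrative):

import importlib

def collect_extensions(module_names):
    # map uid -> Extension subclass without importing any process classes yet
    extensions = {}
    for name in module_names:
        module = importlib.import_module(name)
        for ext in getattr(module, 'AI_TOOLKIT_EXTENSIONS', []):
            extensions[ext.uid] = ext
    return extensions

# only when a config requests type: 'image_reference_slider_trainer' would the loader
# call extensions['image_reference_slider_trainer'].get_process(), deferring the
# torch/diffusers imports until they are actually needed
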
extensions_built_in/image_reference_slider_trainer/config/train.example.yaml ADDED
@@ -0,0 +1,107 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: example_name
5
+ process:
6
+ - type: 'image_reference_slider_trainer'
7
+ training_folder: "/mnt/Train/out/LoRA"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "/home/jaret/Dev/.tensorboard"
11
+ network:
12
+ type: "lora"
13
+ linear: 8
14
+ linear_alpha: 8
15
+ train:
16
+ noise_scheduler: "ddpm" # e.g. "ddpm", "lms", "euler_a"
17
+ steps: 5000
18
+ lr: 1e-4
19
+ train_unet: true
20
+ gradient_checkpointing: true
21
+ train_text_encoder: true
22
+ optimizer: "adamw"
23
+ optimizer_params:
24
+ weight_decay: 1e-2
25
+ lr_scheduler: "constant"
26
+ max_denoising_steps: 1000
27
+ batch_size: 1
28
+ dtype: bf16
29
+ xformers: true
30
+ skip_first_sample: true
31
+ noise_offset: 0.0
32
+ model:
33
+ name_or_path: "/path/to/model.safetensors"
34
+ is_v2: false # for v2 models
35
+ is_xl: false # for SDXL models
36
+ is_v_pred: false # for v-prediction models (most v2 models)
37
+ save:
38
+ dtype: float16 # precision to save
39
+ save_every: 1000 # save every this many steps
40
+ max_step_saves_to_keep: 2 # only affects step counts
41
+ sample:
42
+ sampler: "ddpm" # must match train.noise_scheduler
43
+ sample_every: 100 # sample every this many steps
44
+ width: 512
45
+ height: 512
46
+ prompts:
47
+ - "photo of a woman with red hair taking a selfie --m -3"
48
+ - "photo of a woman with red hair taking a selfie --m -1"
49
+ - "photo of a woman with red hair taking a selfie --m 1"
50
+ - "photo of a woman with red hair taking a selfie --m 3"
51
+ - "close up photo of a man smiling at the camera, in a tank top --m -3"
52
+ - "close up photo of a man smiling at the camera, in a tank top --m -1"
53
+ - "close up photo of a man smiling at the camera, in a tank top --m 1"
54
+ - "close up photo of a man smiling at the camera, in a tank top --m 3"
55
+ - "photo of a blonde woman smiling, barista --m -3"
56
+ - "photo of a blonde woman smiling, barista --m -1"
57
+ - "photo of a blonde woman smiling, barista --m 1"
58
+ - "photo of a blonde woman smiling, barista --m 3"
59
+ - "photo of a Christina Hendricks --m -3"
60
+ - "photo of a Christina Hendricks --m -1"
61
+ - "photo of a Christina Hendricks --m 1"
62
+ - "photo of a Christina Hendricks --m 3"
63
+ - "photo of a Christina Ricci --m -3"
64
+ - "photo of a Christina Ricci --m -1"
65
+ - "photo of a Christina Ricci --m 1"
66
+ - "photo of a Christina Ricci --m 3"
67
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
68
+ seed: 42
69
+ walk_seed: false
70
+ guidance_scale: 7
71
+ sample_steps: 20
72
+ network_multiplier: 1.0
73
+
74
+ logging:
75
+ log_every: 10 # log every this many steps
76
+ use_wandb: false # not supported yet
77
+ verbose: false
78
+
79
+ slider:
80
+ datasets:
81
+ - pair_folder: "/path/to/folder/side/by/side/images"
82
+ network_weight: 2.0
83
+ target_class: "" # only used as a default if caption txt files are not present
84
+ size: 512
85
+ - pair_folder: "/path/to/folder/side/by/side/images"
86
+ network_weight: 4.0
87
+ target_class: "" # only used as a default if caption txt files are not present
88
+ size: 512
89
+
90
+
91
+ # you can put any information you want here, and it will be saved in the model
92
+ # the below is an example. I recommend including trigger words at a minimum
93
+ # in the metadata. The software will include this plus some other information
94
+ meta:
95
+ name: "[name]" # [name] gets replaced with the name above
96
+ description: A short description of your model
97
+ trigger_words:
98
+ - put
99
+ - trigger
100
+ - words
101
+ - here
102
+ version: '0.1'
103
+ creator:
104
+ name: Your Name
105
106
+ website: https://yourwebsite.com
107
+ any: All metadata above is arbitrary; it can be whatever you want.
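
The sample prompts in the config above sweep the slider by appending --m values from -3 to 3, which set the network multiplier for that sample. As an illustrative sketch only (the toolkit's own prompt-flag parser is not part of this diff and also handles other flags), the suffix can be split off like this:

import re

def split_multiplier(prompt: str, default: float = 1.0):
    # pull a trailing "--m <number>" flag off a sample prompt
    match = re.search(r'\s--m\s+(-?\d+(?:\.\d+)?)\s*$', prompt)
    if match is None:
        return prompt, default
    return prompt[:match.start()].strip(), float(match.group(1))

print(split_multiplier("photo of a woman with red hair taking a selfie --m -3"))
# ('photo of a woman with red hair taking a selfie', -3.0)
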
extensions_built_in/sd_trainer/SDTrainer.py ADDED
@@ -0,0 +1,1679 @@
1
+ import os
2
+ import random
3
+ from collections import OrderedDict
4
+ from typing import Union, Literal, List, Optional
5
+
6
+ import numpy as np
7
+ from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel
8
+
9
+ import torch.nn.functional as F
10
+ from safetensors.torch import load_file
11
+ from torch.utils.data import DataLoader, ConcatDataset
12
+
13
+ from toolkit import train_tools
14
+ from toolkit.basic import value_map, adain, get_mean_std
15
+ from toolkit.clip_vision_adapter import ClipVisionAdapter
16
+ from toolkit.config_modules import GuidanceConfig
17
+ from toolkit.data_loader import get_dataloader_datasets
18
+ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
19
+ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
20
+ from toolkit.image_utils import show_tensors, show_latents
21
+ from toolkit.ip_adapter import IPAdapter
22
+ from toolkit.custom_adapter import CustomAdapter
23
+ from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
24
+ from toolkit.reference_adapter import ReferenceAdapter
25
+ from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
26
+ from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
27
+ apply_learnable_snr_gos, LearnableSNRGamma
28
+ import gc
29
+ import torch
30
+ from jobs.process import BaseSDTrainProcess
31
+ from torchvision import transforms
32
+ from diffusers import EMAModel
33
+ import math
34
+ from toolkit.train_tools import precondition_model_outputs_flow_match
35
+
36
+
37
+ def flush():
38
+ torch.cuda.empty_cache()
39
+ gc.collect()
40
+
41
+
42
+ adapter_transforms = transforms.Compose([
43
+ transforms.ToTensor(),
44
+ ])
45
+
46
+
47
+ class SDTrainer(BaseSDTrainProcess):
48
+
49
+ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
50
+ super().__init__(process_id, job, config, **kwargs)
51
+ self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
52
+ self.do_prior_prediction = False
53
+ self.do_long_prompts = False
54
+ self.do_guided_loss = False
55
+ self.taesd: Optional[AutoencoderTiny] = None
56
+
57
+ self._clip_image_embeds_unconditional: Union[List[str], None] = None
58
+ self.negative_prompt_pool: Union[List[str], None] = None
59
+ self.batch_negative_prompt: Union[List[str], None] = None
60
+
61
+ self.scaler = torch.cuda.amp.GradScaler()
62
+
63
+ self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
64
+
65
+ self.do_grad_scale = True
66
+ if self.is_fine_tuning and self.is_bfloat:
67
+ self.do_grad_scale = False
68
+ if self.adapter_config is not None:
69
+ if self.adapter_config.train:
70
+ self.do_grad_scale = False
71
+
72
+ if self.train_config.dtype in ["fp16", "float16"]:
73
+ # patch the scaler to allow fp16 training
74
+ org_unscale_grads = self.scaler._unscale_grads_
75
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
76
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
77
+ self.scaler._unscale_grads_ = _unscale_grads_replacer
78
+
79
+ self.cached_blank_embeds: Optional[PromptEmbeds] = None
80
+ self.cached_trigger_embeds: Optional[PromptEmbeds] = None
81
+
82
+
83
+ def before_model_load(self):
84
+ pass
85
+
86
+ def before_dataset_load(self):
87
+ self.assistant_adapter = None
88
+ # get adapter assistant if one is set
89
+ if self.train_config.adapter_assist_name_or_path is not None:
90
+ adapter_path = self.train_config.adapter_assist_name_or_path
91
+
92
+ if self.train_config.adapter_assist_type == "t2i":
93
+ # dont name this adapter since we are not training it
94
+ self.assistant_adapter = T2IAdapter.from_pretrained(
95
+ adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
96
+ ).to(self.device_torch)
97
+ elif self.train_config.adapter_assist_type == "control_net":
98
+ self.assistant_adapter = ControlNetModel.from_pretrained(
99
+ adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
100
+ ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
101
+ else:
102
+ raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")
103
+
104
+ self.assistant_adapter.eval()
105
+ self.assistant_adapter.requires_grad_(False)
106
+ flush()
107
+ if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
108
+ if self.model_config.is_xl:
109
+ self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl",
110
+ torch_dtype=get_torch_dtype(self.train_config.dtype))
111
+ else:
112
+ self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd",
113
+ torch_dtype=get_torch_dtype(self.train_config.dtype))
114
+ self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch)
115
+ self.taesd.eval()
116
+ self.taesd.requires_grad_(False)
117
+
118
+ def hook_before_train_loop(self):
119
+ super().hook_before_train_loop()
120
+
121
+ if self.train_config.do_prior_divergence:
122
+ self.do_prior_prediction = True
123
+ # move vae to device if we did not cache latents
124
+ if not self.is_latents_cached:
125
+ self.sd.vae.eval()
126
+ self.sd.vae.to(self.device_torch)
127
+ else:
128
+ # offload it. Already cached
129
+ self.sd.vae.to('cpu')
130
+ flush()
131
+ add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
132
+ if self.adapter is not None:
133
+ self.adapter.to(self.device_torch)
134
+
135
+ # check if we have regs and using adapter and caching clip embeddings
136
+ has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
137
+ is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])
138
+
139
+ if has_reg and is_caching_clip_embeddings:
140
+ # we need a list of unconditional clip image embeds from other datasets to handle regs
141
+ unconditional_clip_image_embeds = []
142
+ datasets = get_dataloader_datasets(self.data_loader)
143
+ for i in range(len(datasets)):
144
+ unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache
145
+
146
+ if len(unconditional_clip_image_embeds) == 0:
147
+ raise ValueError("No unconditional clip image embeds found. This should not happen")
148
+
149
+ self._clip_image_embeds_unconditional = unconditional_clip_image_embeds
150
+
151
+ if self.train_config.negative_prompt is not None:
152
+ if os.path.exists(self.train_config.negative_prompt):
153
+ with open(self.train_config.negative_prompt, 'r') as f:
154
+ self.negative_prompt_pool = f.readlines()
155
+ # remove empty
156
+ self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""]
157
+ else:
158
+ # single prompt
159
+ self.negative_prompt_pool = [self.train_config.negative_prompt]
160
+
161
+ # handle unload text encoder
162
+ if self.train_config.unload_text_encoder:
163
+ with torch.no_grad():
164
+ if self.train_config.train_text_encoder:
165
+ raise ValueError("Cannot unload text encoder if training text encoder")
166
+ # cache embeddings
167
+
168
+ print("\n***** UNLOADING TEXT ENCODER *****")
169
+ print("This will train only with a blank prompt or trigger word, if set")
170
+ print("If this is not what you want, remove the unload_text_encoder flag")
171
+ print("***********************************")
172
+ print("")
173
+ self.sd.text_encoder_to(self.device_torch)
174
+ self.cached_blank_embeds = self.sd.encode_prompt("")
175
+ if self.trigger_word is not None:
176
+ self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
177
+
178
+ # move back to cpu
179
+ self.sd.text_encoder_to('cpu')
180
+ flush()
181
+
182
+
183
+ def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
184
+ # to process turbo learning, we make one big step from our current timestep to the end
185
+ # we then denoise the prediction on that remaining step and target our loss to our target latents
186
+ # this currently only works on euler_a (that I know of). It would work on other schedulers, but support needs to be added.
187
+ # this needs to be done per item in the batch, since each item may have a different timestep
188
+ batch_size = pred.shape[0]
189
+ pred_chunks = torch.chunk(pred, batch_size, dim=0)
190
+ noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0)
191
+ timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0)
192
+ latent_chunks = torch.chunk(batch.latents, batch_size, dim=0)
193
+ noise_chunks = torch.chunk(noise, batch_size, dim=0)
194
+
195
+ with torch.no_grad():
196
+ # set the timesteps to 1000 so we can capture them to calculate the sigmas
197
+ self.sd.noise_scheduler.set_timesteps(
198
+ self.sd.noise_scheduler.config.num_train_timesteps,
199
+ device=self.device_torch
200
+ )
201
+ train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach()
202
+
203
+ train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach()
204
+
205
+ # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step
206
+ self.sd.noise_scheduler.set_timesteps(
207
+ 1,
208
+ device=self.device_torch
209
+ )
210
+
211
+ denoised_pred_chunks = []
212
+ target_pred_chunks = []
213
+
214
+ for i in range(batch_size):
215
+ pred_item = pred_chunks[i]
216
+ noisy_latents_item = noisy_latents_chunks[i]
217
+ timesteps_item = timesteps_chunks[i]
218
+ latents_item = latent_chunks[i]
219
+ noise_item = noise_chunks[i]
220
+ with torch.no_grad():
221
+ timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
222
+ single_step_timestep_schedule = [timesteps_item.squeeze().item()]
223
+ # extract the sigma idx for our midpoint timestep
224
+ sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch)
225
+
226
+ end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
227
+ end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch)
228
+
229
+ # add noise to our target
230
+
231
+ # build the big sigma step. The to step will now be to 0 giving it a full remaining denoising half step
232
+ # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach()
233
+ self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach()
234
+ # set our single timstep
235
+ self.sd.noise_scheduler.timesteps = torch.from_numpy(
236
+ np.array(single_step_timestep_schedule, dtype=np.float32)
237
+ ).to(device=self.device_torch)
238
+
239
+ # set the step index to None so it will be recalculated on first step
240
+ self.sd.noise_scheduler._step_index = None
241
+
242
+ denoised_latent = self.sd.noise_scheduler.step(
243
+ pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
244
+ )[0]
245
+
246
+ residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(
247
+ self.train_config.dtype))
248
+ # remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
249
+ denoised_latent = denoised_latent - residual_noise
250
+
251
+ denoised_pred_chunks.append(denoised_latent)
252
+
253
+ denoised_latents = torch.cat(denoised_pred_chunks, dim=0)
254
+ # set the scheduler back to the original timesteps
255
+ self.sd.noise_scheduler.set_timesteps(
256
+ self.sd.noise_scheduler.config.num_train_timesteps,
257
+ device=self.device_torch
258
+ )
259
+
260
+ output = denoised_latents / self.sd.vae.config['scaling_factor']
261
+ output = self.sd.vae.decode(output).sample
262
+
263
+ if self.train_config.show_turbo_outputs:
264
+ # since we are completely denoising, we can show them here
265
+ with torch.no_grad():
266
+ show_tensors(output)
267
+
268
+ # we return our big partial step denoised latents as our pred and our untouched latents as our target.
269
+ # you can do mse against the two here or run the denoised through the vae for pixel space loss against the
270
+ # input tensor images.
271
+
272
+ return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
273
+
274
+ # you can expand these in a child class to make customization easier
275
+ def calculate_loss(
276
+ self,
277
+ noise_pred: torch.Tensor,
278
+ noise: torch.Tensor,
279
+ noisy_latents: torch.Tensor,
280
+ timesteps: torch.Tensor,
281
+ batch: 'DataLoaderBatchDTO',
282
+ mask_multiplier: Union[torch.Tensor, float] = 1.0,
283
+ prior_pred: Union[torch.Tensor, None] = None,
284
+ **kwargs
285
+ ):
286
+ loss_target = self.train_config.loss_target
287
+ is_reg = any(batch.get_is_reg_list())
288
+
289
+ prior_mask_multiplier = None
290
+ target_mask_multiplier = None
291
+ dtype = get_torch_dtype(self.train_config.dtype)
292
+
293
+ has_mask = batch.mask_tensor is not None
294
+
295
+ with torch.no_grad():
296
+ loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)
297
+
298
+ if self.train_config.match_noise_norm:
299
+ # match the norm of the noise
300
+ noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
301
+ noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
302
+ noise_pred = noise_pred * (noise_norm / noise_pred_norm)
303
+
304
+ if self.train_config.pred_scaler != 1.0:
305
+ noise_pred = noise_pred * self.train_config.pred_scaler
306
+
307
+ target = None
308
+
309
+ if self.train_config.target_noise_multiplier != 1.0:
310
+ noise = noise * self.train_config.target_noise_multiplier
311
+
312
+ if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
313
+ if self.train_config.correct_pred_norm and not is_reg:
314
+ with torch.no_grad():
315
+ # this only works if doing a prior pred
316
+ if prior_pred is not None:
317
+ prior_mean = prior_pred.mean([2,3], keepdim=True)
318
+ prior_std = prior_pred.std([2,3], keepdim=True)
319
+ noise_mean = noise_pred.mean([2,3], keepdim=True)
320
+ noise_std = noise_pred.std([2,3], keepdim=True)
321
+
322
+ mean_adjust = prior_mean - noise_mean
323
+ std_adjust = prior_std - noise_std
324
+
325
+ mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
326
+ std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
327
+
328
+ target_mean = noise_mean + mean_adjust
329
+ target_std = noise_std + std_adjust
330
+
331
+ eps = 1e-5
332
+ # match the noise to the prior
333
+ noise = (noise - noise_mean) / (noise_std + eps)
334
+ noise = noise * (target_std + eps) + target_mean
335
+ noise = noise.detach()
336
+
337
+ if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
338
+ assert not self.train_config.train_turbo
339
+ with torch.no_grad():
340
+ # we need to make the noise prediction be a masked blending of noise and prior_pred
341
+ stretched_mask_multiplier = value_map(
342
+ mask_multiplier,
343
+ batch.file_items[0].dataset_config.mask_min_value,
344
+ 1.0,
345
+ 0.0,
346
+ 1.0
347
+ )
348
+
349
+ prior_mask_multiplier = 1.0 - stretched_mask_multiplier
350
+
351
+
352
+ # target_mask_multiplier = mask_multiplier
353
+ # mask_multiplier = 1.0
354
+ target = noise
355
+ # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
356
+ # set masked multiplier to 1.0 so we dont double apply it
357
+ # mask_multiplier = 1.0
358
+ elif prior_pred is not None and not self.train_config.do_prior_divergence:
359
+ assert not self.train_config.train_turbo
360
+ # matching adapter prediction
361
+ target = prior_pred
362
+ elif self.sd.prediction_type == 'v_prediction':
363
+ # v-parameterization training
364
+ target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
365
+
366
+ elif self.sd.is_flow_matching:
367
+ target = (noise - batch.latents).detach()
368
+ else:
369
+ target = noise
370
+
371
+ if target is None:
372
+ target = noise
373
+
374
+ pred = noise_pred
375
+
376
+ if self.train_config.train_turbo:
377
+ pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
378
+
379
+ ignore_snr = False
380
+
381
+ if loss_target == 'source' or loss_target == 'unaugmented':
382
+ assert not self.train_config.train_turbo
383
+ # ignore_snr = True
384
+ if batch.sigmas is None:
385
+ raise ValueError("Batch sigmas is None. This should not happen")
386
+
387
+ # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
388
+ denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
389
+ weighing = batch.sigmas ** -2.0
390
+ if loss_target == 'source':
391
+ # denoise the latent and compare to the latent in the batch
392
+ target = batch.latents
393
+ elif loss_target == 'unaugmented':
394
+ # we have to encode images into latents for now
395
+ # we also denoise, as the unaugmented tensor is not a noisy differential
396
+ with torch.no_grad():
397
+ unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
398
+ unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
399
+ target = unaugmented_latents.detach()
400
+
401
+ # Get the target for loss depending on the prediction type
402
+ if self.sd.noise_scheduler.config.prediction_type == "epsilon":
403
+ target = target # we are computing loss against denoised latents
404
+ elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
405
+ target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
406
+ else:
407
+ raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
408
+
409
+ # mse loss without reduction
410
+ loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
411
+ loss = loss_per_element
412
+ else:
413
+
414
+ if self.train_config.loss_type == "mae":
415
+ loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
416
+ else:
417
+ loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
418
+
419
+ # handle linear timesteps and only adjust the weight of the timesteps
420
+ if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
421
+ # calculate the weights for the timesteps
422
+ timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
423
+ timesteps,
424
+ v2=self.train_config.linear_timesteps2
425
+ ).to(loss.device, dtype=loss.dtype)
426
+ timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
427
+ loss = loss * timestep_weight
428
+
429
+ if self.train_config.do_prior_divergence and prior_pred is not None:
430
+ loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)
431
+
432
+ if self.train_config.train_turbo:
433
+ mask_multiplier = mask_multiplier[:, 3:, :, :]
434
+ # resize to the size of the loss
435
+ mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
436
+
437
+ # multiply by our mask
438
+ loss = loss * mask_multiplier
439
+
440
+ prior_loss = None
441
+ if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
442
+ assert not self.train_config.train_turbo
443
+ if self.train_config.loss_type == "mae":
444
+ prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
445
+ else:
446
+ prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")
447
+
448
+ prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
449
+ if torch.isnan(prior_loss).any():
450
+ print("Prior loss is nan")
451
+ prior_loss = None
452
+ else:
453
+ prior_loss = prior_loss.mean([1, 2, 3])
454
+ # loss = loss + prior_loss
455
+ # loss = loss + prior_loss
456
+ # loss = loss + prior_loss
457
+ loss = loss.mean([1, 2, 3])
458
+ # apply loss multiplier before prior loss
459
+ loss = loss * loss_multiplier
460
+ if prior_loss is not None:
461
+ loss = loss + prior_loss
462
+
463
+ if not self.train_config.train_turbo:
464
+ if self.train_config.learnable_snr_gos:
465
+ # add snr_gamma
466
+ loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
467
+ elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
468
+ # add snr_gamma
469
+ loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma,
470
+ fixed=True)
471
+ elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
472
+ # add min_snr_gamma
473
+ loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
474
+
475
+ loss = loss.mean()
476
+
477
+ # check for additional losses
478
+ if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
479
+
480
+ loss = loss + self.adapter.additional_loss.mean()
481
+ self.adapter.additional_loss = None
482
+
483
+ if self.train_config.target_norm_std:
484
+ # separate out the batch and channels
485
+ pred_std = noise_pred.std([2, 3], keepdim=True)
486
+ norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
487
+ loss = loss + norm_std_loss
488
+
489
+
490
+ return loss
491
+
492
+ def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
493
+ return batch
494
+
495
+ def get_guided_loss(
496
+ self,
497
+ noisy_latents: torch.Tensor,
498
+ conditional_embeds: PromptEmbeds,
499
+ match_adapter_assist: bool,
500
+ network_weight_list: list,
501
+ timesteps: torch.Tensor,
502
+ pred_kwargs: dict,
503
+ batch: 'DataLoaderBatchDTO',
504
+ noise: torch.Tensor,
505
+ unconditional_embeds: Optional[PromptEmbeds] = None,
506
+ **kwargs
507
+ ):
508
+ loss = get_guidance_loss(
509
+ noisy_latents=noisy_latents,
510
+ conditional_embeds=conditional_embeds,
511
+ match_adapter_assist=match_adapter_assist,
512
+ network_weight_list=network_weight_list,
513
+ timesteps=timesteps,
514
+ pred_kwargs=pred_kwargs,
515
+ batch=batch,
516
+ noise=noise,
517
+ sd=self.sd,
518
+ unconditional_embeds=unconditional_embeds,
519
+ scaler=self.scaler,
520
+ **kwargs
521
+ )
522
+
523
+ return loss
524
+
525
+ def get_guided_loss_targeted_polarity(
526
+ self,
527
+ noisy_latents: torch.Tensor,
528
+ conditional_embeds: PromptEmbeds,
529
+ match_adapter_assist: bool,
530
+ network_weight_list: list,
531
+ timesteps: torch.Tensor,
532
+ pred_kwargs: dict,
533
+ batch: 'DataLoaderBatchDTO',
534
+ noise: torch.Tensor,
535
+ **kwargs
536
+ ):
537
+ with torch.no_grad():
538
+ # Perform targeted guidance (working title)
539
+ dtype = get_torch_dtype(self.train_config.dtype)
540
+
541
+ conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
542
+ unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
543
+
544
+ mean_latents = (conditional_latents + unconditional_latents) / 2.0
545
+
546
+ unconditional_diff = (unconditional_latents - mean_latents)
547
+ conditional_diff = (conditional_latents - mean_latents)
548
+
549
+ # we need to determine the amount of signal and noise that would be present at the current timestep
550
+ # conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
551
+ # unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
552
+ # unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
553
+ # conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
554
+ # unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
555
+
556
+ # target_noise = noise + unconditional_signal
557
+
558
+ conditional_noisy_latents = self.sd.add_noise(
559
+ mean_latents,
560
+ noise,
561
+ timesteps
562
+ ).detach()
563
+
564
+ unconditional_noisy_latents = self.sd.add_noise(
565
+ mean_latents,
566
+ noise,
567
+ timesteps
568
+ ).detach()
569
+
570
+ # Disable the LoRA network so we can predict parent network knowledge without it
571
+ self.network.is_active = False
572
+ self.sd.unet.eval()
573
+
574
+ # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
575
+ # This acts as our control to preserve the unaltered parts of the image.
576
+ baseline_prediction = self.sd.predict_noise(
577
+ latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
578
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
579
+ timestep=timesteps,
580
+ guidance_scale=1.0,
581
+ **pred_kwargs # adapter residuals in here
582
+ ).detach()
583
+
584
+ # double up everything to run it through all at once
585
+ cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
586
+ cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0)
587
+ cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
588
+
589
+ # since we are dividing the polarity from the middle out, we need to double our network
590
+ # weights on training since the convergent point will be at half network strength
591
+
592
+ negative_network_weights = [weight * -2.0 for weight in network_weight_list]
593
+ positive_network_weights = [weight * 2.0 for weight in network_weight_list]
594
+ cat_network_weight_list = positive_network_weights + negative_network_weights
595
+
596
+ # turn the LoRA network back on.
597
+ self.sd.unet.train()
598
+ self.network.is_active = True
599
+
600
+ self.network.multiplier = cat_network_weight_list
601
+
602
+ # do our prediction with LoRA active on the scaled guidance latents
603
+ prediction = self.sd.predict_noise(
604
+ latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
605
+ conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
606
+ timestep=cat_timesteps,
607
+ guidance_scale=1.0,
608
+ **pred_kwargs # adapter residuals in here
609
+ )
610
+
611
+ pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
612
+
613
+ pred_pos = pred_pos - baseline_prediction
614
+ pred_neg = pred_neg - baseline_prediction
615
+
616
+ pred_loss = torch.nn.functional.mse_loss(
617
+ pred_pos.float(),
618
+ unconditional_diff.float(),
619
+ reduction="none"
620
+ )
621
+ pred_loss = pred_loss.mean([1, 2, 3])
622
+
623
+ pred_neg_loss = torch.nn.functional.mse_loss(
624
+ pred_neg.float(),
625
+ conditional_diff.float(),
626
+ reduction="none"
627
+ )
628
+ pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
629
+
630
+ loss = (pred_loss + pred_neg_loss) / 2.0
631
+
632
+ # loss = self.apply_snr(loss, timesteps)
633
+ loss = loss.mean()
634
+ loss.backward()
635
+
636
+ # detach it so parent class can run backward on no grads without throwing error
637
+ loss = loss.detach()
638
+ loss.requires_grad_(True)
639
+
640
+ return loss
641
+
642
+ def get_guided_loss_masked_polarity(
643
+ self,
644
+ noisy_latents: torch.Tensor,
645
+ conditional_embeds: PromptEmbeds,
646
+ match_adapter_assist: bool,
647
+ network_weight_list: list,
648
+ timesteps: torch.Tensor,
649
+ pred_kwargs: dict,
650
+ batch: 'DataLoaderBatchDTO',
651
+ noise: torch.Tensor,
652
+ **kwargs
653
+ ):
654
+ with torch.no_grad():
655
+ # Perform targeted guidance (working title)
656
+ dtype = get_torch_dtype(self.train_config.dtype)
657
+
658
+ conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
659
+ unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
660
+ inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents)
661
+
662
+ mean_latents = (conditional_latents + unconditional_latents) / 2.0
663
+
664
+ # unconditional_diff = (unconditional_latents - mean_latents)
665
+ # conditional_diff = (conditional_latents - mean_latents)
666
+
667
+ # we need to determine the amount of signal and noise that would be present at the current timestep
668
+ # conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
669
+ # unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
670
+ # unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
671
+ # conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
672
+ # unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
673
+
674
+ # make a differential mask
675
+ differential_mask = torch.abs(conditional_latents - unconditional_latents)
676
+ max_differential = \
677
+ differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
678
+ differential_scaler = 1.0 / max_differential
679
+ differential_mask = differential_mask * differential_scaler
680
+ spread_point = 0.1
681
+ # adjust mask to amplify the differential at 0.1
682
+ differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
683
+ # clip it
684
+ differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
685
+
686
+ # target_noise = noise + unconditional_signal
687
+
688
+ conditional_noisy_latents = self.sd.add_noise(
689
+ conditional_latents,
690
+ noise,
691
+ timesteps
692
+ ).detach()
693
+
694
+ unconditional_noisy_latents = self.sd.add_noise(
695
+ unconditional_latents,
696
+ noise,
697
+ timesteps
698
+ ).detach()
699
+
700
+ inverse_noisy_latents = self.sd.add_noise(
701
+ inverse_latents,
702
+ noise,
703
+ timesteps
704
+ ).detach()
705
+
706
+ # Disable the LoRA network so we can predict parent network knowledge without it
707
+ self.network.is_active = False
708
+ self.sd.unet.eval()
709
+
710
+ # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
711
+ # This acts as our control to preserve the unaltered parts of the image.
712
+ # baseline_prediction = self.sd.predict_noise(
713
+ # latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
714
+ # conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
715
+ # timestep=timesteps,
716
+ # guidance_scale=1.0,
717
+ # **pred_kwargs # adapter residuals in here
718
+ # ).detach()
719
+
720
+ # double up everything to run it through all at once
721
+ cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
722
+ cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
723
+ cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
724
+
725
+ # since we are dividing the polarity from the middle out, we need to double our network
726
+ # weights on training since the convergent point will be at half network strength
727
+
728
+ negative_network_weights = [weight * -1.0 for weight in network_weight_list]
729
+ positive_network_weights = [weight * 1.0 for weight in network_weight_list]
730
+ cat_network_weight_list = positive_network_weights + negative_network_weights
731
+
732
+ # turn the LoRA network back on.
733
+ self.sd.unet.train()
734
+ self.network.is_active = True
735
+
736
+ self.network.multiplier = cat_network_weight_list
737
+
738
+ # do our prediction with LoRA active on the scaled guidance latents
739
+ prediction = self.sd.predict_noise(
740
+ latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
741
+ conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
742
+ timestep=cat_timesteps,
743
+ guidance_scale=1.0,
744
+ **pred_kwargs # adapter residuals in here
745
+ )
746
+
747
+ pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
748
+
749
+ # create a loss to balance the mean to 0 between the two predictions
750
+ differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0
751
+
752
+ # pred_pos = pred_pos - baseline_prediction
753
+ # pred_neg = pred_neg - baseline_prediction
754
+
755
+ pred_loss = torch.nn.functional.mse_loss(
756
+ pred_pos.float(),
757
+ noise.float(),
758
+ reduction="none"
759
+ )
760
+ # apply mask
761
+ pred_loss = pred_loss * (1.0 + differential_mask)
762
+ pred_loss = pred_loss.mean([1, 2, 3])
763
+
764
+ pred_neg_loss = torch.nn.functional.mse_loss(
765
+ pred_neg.float(),
766
+ noise.float(),
767
+ reduction="none"
768
+ )
769
+ # apply inverse mask
770
+ pred_neg_loss = pred_neg_loss * (1.0 - differential_mask)
771
+ pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
772
+
773
+ # make a loss to balance to losses of the pos and neg so they are equal
774
+ # differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss)
775
+ #
776
+ # differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss
777
+ #
778
+ # # add a multiplier to balancing losses to make them the top priority
779
+ # differential_mean_loss = differential_mean_loss
780
+
781
+ # remove the grads from the negative as it is only a balancing loss
782
+ # pred_neg_loss = pred_neg_loss.detach()
783
+
784
+ # loss = pred_loss + pred_neg_loss + differential_mean_loss
785
+ loss = pred_loss + pred_neg_loss
786
+
787
+ # loss = self.apply_snr(loss, timesteps)
788
+ loss = loss.mean()
789
+ loss.backward()
790
+
791
+ # detach it so parent class can run backward on no grads without throwing error
792
+ loss = loss.detach()
793
+ loss.requires_grad_(True)
794
+
795
+ return loss
796
+
797
+ def get_prior_prediction(
798
+ self,
799
+ noisy_latents: torch.Tensor,
800
+ conditional_embeds: PromptEmbeds,
801
+ match_adapter_assist: bool,
802
+ network_weight_list: list,
803
+ timesteps: torch.Tensor,
804
+ pred_kwargs: dict,
805
+ batch: 'DataLoaderBatchDTO',
806
+ noise: torch.Tensor,
807
+ unconditional_embeds: Optional[PromptEmbeds] = None,
808
+ conditioned_prompts=None,
809
+ **kwargs
810
+ ):
811
+ # todo for embeddings, we need to run without trigger words
812
+ was_unet_training = self.sd.unet.training
813
+ was_network_active = False
814
+ if self.network is not None:
815
+ was_network_active = self.network.is_active
816
+ self.network.is_active = False
817
+ can_disable_adapter = False
818
+ was_adapter_active = False
819
+ if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or
820
+ isinstance(self.adapter, ReferenceAdapter) or
821
+ (isinstance(self.adapter, CustomAdapter))
822
+ ):
823
+ can_disable_adapter = True
824
+ was_adapter_active = self.adapter.is_active
825
+ self.adapter.is_active = False
826
+
827
+ if self.train_config.unload_text_encoder:
828
+ raise ValueError("Prior predictions currently do not support unloading text encoder")
829
+ # do a prediction here so we can match its output with network multiplier set to 0.0
830
+ with torch.no_grad():
831
+ dtype = get_torch_dtype(self.train_config.dtype)
832
+
833
+ embeds_to_use = conditional_embeds.clone().detach()
834
+ # handle clip vision adapter by removing triggers from prompt and replacing with the class name
835
+ if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
836
+ prompt_list = batch.get_caption_list()
837
+ class_name = ''
838
+
839
+ triggers = ['[trigger]', '[name]']
840
+ remove_tokens = []
841
+
842
+ if self.embed_config is not None:
843
+ triggers.append(self.embed_config.trigger)
844
+ for i in range(1, self.embed_config.tokens):
845
+ remove_tokens.append(f"{self.embed_config.trigger}_{i}")
846
+ if self.embed_config.trigger_class_name is not None:
847
+ class_name = self.embed_config.trigger_class_name
848
+
849
+ if self.adapter is not None:
850
+ triggers.append(self.adapter_config.trigger)
851
+ for i in range(1, self.adapter_config.num_tokens):
852
+ remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
853
+ if self.adapter_config.trigger_class_name is not None:
854
+ class_name = self.adapter_config.trigger_class_name
855
+
856
+ for idx, prompt in enumerate(prompt_list):
857
+ for remove_token in remove_tokens:
858
+ prompt = prompt.replace(remove_token, '')
859
+ for trigger in triggers:
860
+ prompt = prompt.replace(trigger, class_name)
861
+ prompt_list[idx] = prompt
862
+
863
+ embeds_to_use = self.sd.encode_prompt(
864
+ prompt_list,
865
+ long_prompts=self.do_long_prompts).to(
866
+ self.device_torch,
867
+ dtype=dtype).detach()
868
+
869
+ # don't use the network on this
870
+ # self.network.multiplier = 0.0
871
+ self.sd.unet.eval()
872
+
873
+ if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
874
+ # we need to remove the image embeds from the prompt except for flux
875
+ embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
876
+ end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
877
+ embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
878
+ if unconditional_embeds is not None:
879
+ unconditional_embeds = unconditional_embeds.clone().detach()
880
+ unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]
881
+
882
+ if unconditional_embeds is not None:
883
+ unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
884
+
885
+ prior_pred = self.sd.predict_noise(
886
+ latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
887
+ conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
888
+ unconditional_embeddings=unconditional_embeds,
889
+ timestep=timesteps,
890
+ guidance_scale=self.train_config.cfg_scale,
891
+ rescale_cfg=self.train_config.cfg_rescale,
892
+ **pred_kwargs # adapter residuals in here
893
+ )
894
+ if was_unet_training:
895
+ self.sd.unet.train()
896
+ prior_pred = prior_pred.detach()
897
+ # remove the residuals as we won't use them on prediction when matching control
898
+ if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
899
+ del pred_kwargs['down_intrablock_additional_residuals']
900
+ if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
901
+ del pred_kwargs['down_block_additional_residuals']
902
+ if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
903
+ del pred_kwargs['mid_block_additional_residual']
904
+
905
+ if can_disable_adapter:
906
+ self.adapter.is_active = was_adapter_active
907
+ # restore network
908
+ # self.network.multiplier = network_weight_list
909
+ if self.network is not None:
910
+ self.network.is_active = was_network_active
911
+ return prior_pred
912
+
913
+ def before_unet_predict(self):
914
+ pass
915
+
916
+ def after_unet_predict(self):
917
+ pass
918
+
919
+ def end_of_training_loop(self):
920
+ pass
921
+
922
+ def predict_noise(
923
+ self,
924
+ noisy_latents: torch.Tensor,
925
+ timesteps: Union[int, torch.Tensor] = 1,
926
+ conditional_embeds: Union[PromptEmbeds, None] = None,
927
+ unconditional_embeds: Union[PromptEmbeds, None] = None,
928
+ **kwargs,
929
+ ):
930
+ dtype = get_torch_dtype(self.train_config.dtype)
931
+ return self.sd.predict_noise(
932
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
933
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
934
+ unconditional_embeddings=unconditional_embeds,
935
+ timestep=timesteps,
936
+ guidance_scale=self.train_config.cfg_scale,
937
+ guidance_embedding_scale=self.train_config.cfg_scale,
938
+ detach_unconditional=False,
939
+ rescale_cfg=self.train_config.cfg_rescale,
940
+ bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
941
+ **kwargs
942
+ )
943
+
944
+ def train_single_accumulation(self, batch: DataLoaderBatchDTO):
945
+ self.timer.start('preprocess_batch')
946
+ batch = self.preprocess_batch(batch)
947
+ dtype = get_torch_dtype(self.train_config.dtype)
948
+ # sanity check
949
+ if self.sd.vae.dtype != self.sd.vae_torch_dtype:
950
+ self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
951
+ if isinstance(self.sd.text_encoder, list):
952
+ for encoder in self.sd.text_encoder:
953
+ if encoder.dtype != self.sd.te_torch_dtype:
954
+ encoder.to(self.sd.te_torch_dtype)
955
+ else:
956
+ if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
957
+ self.sd.text_encoder.to(self.sd.te_torch_dtype)
958
+
959
+ noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
960
+ if self.train_config.do_cfg or self.train_config.do_random_cfg:
961
+ # pick random negative prompts
962
+ if self.negative_prompt_pool is not None:
963
+ negative_prompts = []
964
+ for i in range(noisy_latents.shape[0]):
965
+ num_neg = random.randint(1, self.train_config.max_negative_prompts)
966
+ this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
967
+ this_neg_prompt = ', '.join(this_neg_prompts)
968
+ negative_prompts.append(this_neg_prompt)
969
+ self.batch_negative_prompt = negative_prompts
970
+ else:
971
+ self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
972
+
973
+ if self.adapter and isinstance(self.adapter, CustomAdapter):
974
+ # condition the prompt
975
+ # todo handle more than one adapter image
976
+ self.adapter.num_control_images = 1
977
+ conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
978
+
979
+ network_weight_list = batch.get_network_weight_list()
980
+ if self.train_config.single_item_batching:
981
+ network_weight_list = network_weight_list + network_weight_list
982
+
983
+ has_adapter_img = batch.control_tensor is not None
984
+ has_clip_image = batch.clip_image_tensor is not None
985
+ has_clip_image_embeds = batch.clip_image_embeds is not None
986
+ # force it to be true if doing regs as we handle those differently
987
+ if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
988
+ has_clip_image = True
989
+ if self._clip_image_embeds_unconditional is not None:
990
+ has_clip_image_embeds = True # we are caching embeds, handle that differently
991
+ has_clip_image = False
992
+
993
+ if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
994
+ raise ValueError(
995
+ "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
996
+
997
+ match_adapter_assist = False
998
+
999
+ # check if we are matching the adapter assistant
1000
+ if self.assistant_adapter:
1001
+ if self.train_config.match_adapter_chance == 1.0:
1002
+ match_adapter_assist = True
1003
+ elif self.train_config.match_adapter_chance > 0.0:
1004
+ match_adapter_assist = torch.rand(
1005
+ (1,), device=self.device_torch, dtype=dtype
1006
+ ) < self.train_config.match_adapter_chance
1007
+
1008
+ self.timer.stop('preprocess_batch')
1009
+
1010
+ is_reg = False
1011
+ with torch.no_grad():
1012
+ loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
1013
+ for idx, file_item in enumerate(batch.file_items):
1014
+ if file_item.is_reg:
1015
+ loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
1016
+ is_reg = True
1017
+
1018
+ adapter_images = None
1019
+ sigmas = None
1020
+ if has_adapter_img and (self.adapter or self.assistant_adapter):
1021
+ with self.timer('get_adapter_images'):
1022
+ # todo move this to data loader
1023
+ if batch.control_tensor is not None:
1024
+ adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
1025
+ # match in channels
1026
+ if self.assistant_adapter is not None:
1027
+ in_channels = self.assistant_adapter.config.in_channels
1028
+ if adapter_images.shape[1] != in_channels:
1029
+ # we need to match the channels
1030
+ adapter_images = adapter_images[:, :in_channels, :, :]
1031
+ else:
1032
+ raise NotImplementedError("Adapter images now must be loaded with dataloader")
1033
+
1034
+ clip_images = None
1035
+ if has_clip_image:
1036
+ with self.timer('get_clip_images'):
1037
+ # todo move this to data loader
1038
+ if batch.clip_image_tensor is not None:
1039
+ clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
1040
+
1041
+ mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
1042
+ if batch.mask_tensor is not None:
1043
+ with self.timer('get_mask_multiplier'):
1044
+ # upsampling not supported for bfloat16
1045
+ mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
1046
+ # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
1047
+ mask_multiplier = torch.nn.functional.interpolate(
1048
+ mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
1049
+ )
1050
+ # expand to match latents
1051
+ mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
1052
+ mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
1053
+
1054
+ def get_adapter_multiplier():
1055
+ if self.adapter and isinstance(self.adapter, T2IAdapter):
1056
+ # training a t2i adapter, not using as assistant.
1057
+ return 1.0
1058
+ elif match_adapter_assist:
1059
+ # training a texture. We want it high
1060
+ adapter_strength_min = 0.9
1061
+ adapter_strength_max = 1.0
1062
+ else:
1063
+ # training with assistance, we want it low
1064
+ # adapter_strength_min = 0.4
1065
+ # adapter_strength_max = 0.7
1066
+ adapter_strength_min = 0.5
1067
+ adapter_strength_max = 1.1
1068
+
1069
+ adapter_conditioning_scale = torch.rand(
1070
+ (1,), device=self.device_torch, dtype=dtype
1071
+ )
1072
+
1073
+ adapter_conditioning_scale = value_map(
1074
+ adapter_conditioning_scale,
1075
+ 0.0,
1076
+ 1.0,
1077
+ adapter_strength_min,
1078
+ adapter_strength_max
1079
+ )
1080
+ return adapter_conditioning_scale
1081
+
1082
+ # flush()
1083
+ with self.timer('grad_setup'):
1084
+
1085
+ # text encoding
1086
+ grad_on_text_encoder = False
1087
+ if self.train_config.train_text_encoder:
1088
+ grad_on_text_encoder = True
1089
+
1090
+ if self.embedding is not None:
1091
+ grad_on_text_encoder = True
1092
+
1093
+ if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
1094
+ grad_on_text_encoder = True
1095
+
1096
+ if self.adapter_config and self.adapter_config.type == 'te_augmenter':
1097
+ grad_on_text_encoder = True
1098
+
1099
+ # have a blank network so we can wrap it in a context and set multipliers without checking every time
1100
+ if self.network is not None:
1101
+ network = self.network
1102
+ else:
1103
+ network = BlankNetwork()
1104
+
1105
+ # set the weights
1106
+ network.multiplier = network_weight_list
1107
+
1108
+ # activate network if it exists
1109
+
1110
+ prompts_1 = conditioned_prompts
1111
+ prompts_2 = None
1112
+ if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
1113
+ prompts_1 = batch.get_caption_short_list()
1114
+ prompts_2 = conditioned_prompts
1115
+
1116
+ # make the batch splits
1117
+ if self.train_config.single_item_batching:
1118
+ if self.model_config.refiner_name_or_path is not None:
1119
+ raise ValueError("Single item batching is not supported when training the refiner")
1120
+ batch_size = noisy_latents.shape[0]
1121
+ # chunk/split everything
1122
+ noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
1123
+ noise_list = torch.chunk(noise, batch_size, dim=0)
1124
+ timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
1125
+ conditioned_prompts_list = [[prompt] for prompt in prompts_1]
1126
+ if imgs is not None:
1127
+ imgs_list = torch.chunk(imgs, batch_size, dim=0)
1128
+ else:
1129
+ imgs_list = [None for _ in range(batch_size)]
1130
+ if adapter_images is not None:
1131
+ adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
1132
+ else:
1133
+ adapter_images_list = [None for _ in range(batch_size)]
1134
+ if clip_images is not None:
1135
+ clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
1136
+ else:
1137
+ clip_images_list = [None for _ in range(batch_size)]
1138
+ mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
1139
+ if prompts_2 is None:
1140
+ prompt_2_list = [None for _ in range(batch_size)]
1141
+ else:
1142
+ prompt_2_list = [[prompt] for prompt in prompts_2]
1143
+
1144
+ else:
1145
+ noisy_latents_list = [noisy_latents]
1146
+ noise_list = [noise]
1147
+ timesteps_list = [timesteps]
1148
+ conditioned_prompts_list = [prompts_1]
1149
+ imgs_list = [imgs]
1150
+ adapter_images_list = [adapter_images]
1151
+ clip_images_list = [clip_images]
1152
+ mask_multiplier_list = [mask_multiplier]
1153
+ if prompts_2 is None:
1154
+ prompt_2_list = [None]
1155
+ else:
1156
+ prompt_2_list = [prompts_2]
1157
+
1158
+ for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip(
1159
+ noisy_latents_list,
1160
+ noise_list,
1161
+ timesteps_list,
1162
+ conditioned_prompts_list,
1163
+ imgs_list,
1164
+ adapter_images_list,
1165
+ clip_images_list,
1166
+ mask_multiplier_list,
1167
+ prompt_2_list
1168
+ ):
1169
+
1170
+ # if self.train_config.negative_prompt is not None:
1171
+ # # add negative prompt
1172
+ # conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
1173
+ # range(len(conditioned_prompts))]
1174
+ # if prompt_2 is not None:
1175
+ # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
1176
+
1177
+ with (network):
1178
+ # encode clip adapter here so embeds are active for tokenizer
1179
+ if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
1180
+ with self.timer('encode_clip_vision_embeds'):
1181
+ if has_clip_image:
1182
+ conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
1183
+ clip_images.detach().to(self.device_torch, dtype=dtype),
1184
+ is_training=True,
1185
+ has_been_preprocessed=True
1186
+ )
1187
+ else:
1188
+ # just do a blank one
1189
+ conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
1190
+ torch.zeros(
1191
+ (noisy_latents.shape[0], 3, 512, 512),
1192
+ device=self.device_torch, dtype=dtype
1193
+ ),
1194
+ is_training=True,
1195
+ has_been_preprocessed=True,
1196
+ drop=True
1197
+ )
1198
+ # it will be injected into the tokenizer when called
1199
+ self.adapter(conditional_clip_embeds)
1200
+
1201
+ # run the custom adapter's pre-text-encoder step before encoding prompts
1202
+ if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg):
1203
+ quad_count = random.randint(1, 4)
1204
+ self.adapter.train()
1205
+ self.adapter.trigger_pre_te(
1206
+ tensors_0_1=clip_images if not is_reg else None, # on regs we send none to get random noise
1207
+ is_training=True,
1208
+ has_been_preprocessed=True,
1209
+ quad_count=quad_count,
1210
+ batch_size=noisy_latents.shape[0]
1211
+ )
1212
+
1213
+ with self.timer('encode_prompt'):
1214
+ unconditional_embeds = None
1215
+ if self.train_config.unload_text_encoder:
1216
+ with torch.set_grad_enabled(False):
1217
+ embeds_to_use = self.cached_blank_embeds.clone().detach().to(
1218
+ self.device_torch, dtype=dtype
1219
+ )
1220
+ if self.cached_trigger_embeds is not None and not is_reg:
1221
+ embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
1222
+ self.device_torch, dtype=dtype
1223
+ )
1224
+ conditional_embeds = concat_prompt_embeds(
1225
+ [embeds_to_use] * noisy_latents.shape[0]
1226
+ )
1227
+ if self.train_config.do_cfg:
1228
+ unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
1229
+ self.device_torch, dtype=dtype
1230
+ )
1231
+ unconditional_embeds = concat_prompt_embeds(
1232
+ [unconditional_embeds] * noisy_latents.shape[0]
1233
+ )
1234
+
1235
+ if isinstance(self.adapter, CustomAdapter):
1236
+ self.adapter.is_unconditional_run = False
1237
+
1238
+ elif grad_on_text_encoder:
1239
+ with torch.set_grad_enabled(True):
1240
+ if isinstance(self.adapter, CustomAdapter):
1241
+ self.adapter.is_unconditional_run = False
1242
+ conditional_embeds = self.sd.encode_prompt(
1243
+ conditioned_prompts, prompt_2,
1244
+ dropout_prob=self.train_config.prompt_dropout_prob,
1245
+ long_prompts=self.do_long_prompts).to(
1246
+ self.device_torch,
1247
+ dtype=dtype)
1248
+
1249
+ if self.train_config.do_cfg:
1250
+ if isinstance(self.adapter, CustomAdapter):
1251
+ self.adapter.is_unconditional_run = True
1252
+ # todo only do one and repeat it
1253
+ unconditional_embeds = self.sd.encode_prompt(
1254
+ self.batch_negative_prompt,
1255
+ self.batch_negative_prompt,
1256
+ dropout_prob=self.train_config.prompt_dropout_prob,
1257
+ long_prompts=self.do_long_prompts).to(
1258
+ self.device_torch,
1259
+ dtype=dtype)
1260
+ if isinstance(self.adapter, CustomAdapter):
1261
+ self.adapter.is_unconditional_run = False
1262
+ else:
1263
+ with torch.set_grad_enabled(False):
1264
+ # make sure it is in eval mode
1265
+ if isinstance(self.sd.text_encoder, list):
1266
+ for te in self.sd.text_encoder:
1267
+ te.eval()
1268
+ else:
1269
+ self.sd.text_encoder.eval()
1270
+ if isinstance(self.adapter, CustomAdapter):
1271
+ self.adapter.is_unconditional_run = False
1272
+ conditional_embeds = self.sd.encode_prompt(
1273
+ conditioned_prompts, prompt_2,
1274
+ dropout_prob=self.train_config.prompt_dropout_prob,
1275
+ long_prompts=self.do_long_prompts).to(
1276
+ self.device_torch,
1277
+ dtype=dtype)
1278
+ if self.train_config.do_cfg:
1279
+ if isinstance(self.adapter, CustomAdapter):
1280
+ self.adapter.is_unconditional_run = True
1281
+ unconditional_embeds = self.sd.encode_prompt(
1282
+ self.batch_negative_prompt,
1283
+ dropout_prob=self.train_config.prompt_dropout_prob,
1284
+ long_prompts=self.do_long_prompts).to(
1285
+ self.device_torch,
1286
+ dtype=dtype)
1287
+ if isinstance(self.adapter, CustomAdapter):
1288
+ self.adapter.is_unconditional_run = False
1289
+
1290
+ # detach the embeddings
1291
+ conditional_embeds = conditional_embeds.detach()
1292
+ if self.train_config.do_cfg:
1293
+ unconditional_embeds = unconditional_embeds.detach()
1294
+
1295
+ if self.decorator:
1296
+ conditional_embeds.text_embeds = self.decorator(
1297
+ conditional_embeds.text_embeds
1298
+ )
1299
+ if self.train_config.do_cfg:
1300
+ unconditional_embeds.text_embeds = self.decorator(
1301
+ unconditional_embeds.text_embeds,
1302
+ is_unconditional=True
1303
+ )
1304
+
1305
+ # flush()
1306
+ pred_kwargs = {}
1307
+
1308
+ if has_adapter_img:
1309
+ if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
1310
+ self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
1311
+ with torch.set_grad_enabled(self.adapter is not None):
1312
+ adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
1313
+ adapter_multiplier = get_adapter_multiplier()
1314
+ with self.timer('encode_adapter'):
1315
+ down_block_additional_residuals = adapter(adapter_images)
1316
+ if self.assistant_adapter:
1317
+ # not training. detach
1318
+ down_block_additional_residuals = [
1319
+ sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
1320
+ down_block_additional_residuals
1321
+ ]
1322
+ else:
1323
+ down_block_additional_residuals = [
1324
+ sample.to(dtype=dtype) * adapter_multiplier for sample in
1325
+ down_block_additional_residuals
1326
+ ]
1327
+
1328
+ pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
1329
+
1330
+ if self.adapter and isinstance(self.adapter, IPAdapter):
1331
+ with self.timer('encode_adapter_embeds'):
1332
+ # number of images to do if doing a quad image
1333
+ quad_count = random.randint(1, 4)
1334
+ image_size = self.adapter.input_size
1335
+ if has_clip_image_embeds:
1336
+ # todo handle reg images better than this
1337
+ if is_reg:
1338
+ # get unconditional image embeds from cache
1339
+ embeds = [
1340
+ load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
1341
+ range(noisy_latents.shape[0])
1342
+ ]
1343
+ conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
1344
+ embeds,
1345
+ quad_count=quad_count
1346
+ )
1347
+
1348
+ if self.train_config.do_cfg:
1349
+ embeds = [
1350
+ load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
1351
+ range(noisy_latents.shape[0])
1352
+ ]
1353
+ unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
1354
+ embeds,
1355
+ quad_count=quad_count
1356
+ )
1357
+
1358
+ else:
1359
+ conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
1360
+ batch.clip_image_embeds,
1361
+ quad_count=quad_count
1362
+ )
1363
+ if self.train_config.do_cfg:
1364
+ unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
1365
+ batch.clip_image_embeds_unconditional,
1366
+ quad_count=quad_count
1367
+ )
1368
+ elif is_reg:
1369
+ # we will zero it out in the img embedder
1370
+ clip_images = torch.zeros(
1371
+ (noisy_latents.shape[0], 3, image_size, image_size),
1372
+ device=self.device_torch, dtype=dtype
1373
+ ).detach()
1374
+ # drop will zero it out
1375
+ conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
1376
+ clip_images,
1377
+ drop=True,
1378
+ is_training=True,
1379
+ has_been_preprocessed=False,
1380
+ quad_count=quad_count
1381
+ )
1382
+ if self.train_config.do_cfg:
1383
+ unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
1384
+ torch.zeros(
1385
+ (noisy_latents.shape[0], 3, image_size, image_size),
1386
+ device=self.device_torch, dtype=dtype
1387
+ ).detach(),
1388
+ is_training=True,
1389
+ drop=True,
1390
+ has_been_preprocessed=False,
1391
+ quad_count=quad_count
1392
+ )
1393
+ elif has_clip_image:
1394
+ conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
1395
+ clip_images.detach().to(self.device_torch, dtype=dtype),
1396
+ is_training=True,
1397
+ has_been_preprocessed=True,
1398
+ quad_count=quad_count,
1399
+ # do cfg on clip embeds to normalize the embeddings for when doing cfg
1400
+ # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
1401
+ # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
1402
+ )
1403
+ if self.train_config.do_cfg:
1404
+ unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
1405
+ clip_images.detach().to(self.device_torch, dtype=dtype),
1406
+ is_training=True,
1407
+ drop=True,
1408
+ has_been_preprocessed=True,
1409
+ quad_count=quad_count
1410
+ )
1411
+ else:
1412
+ print("No Clip Image")
1413
+ print([file_item.path for file_item in batch.file_items])
1414
+ raise ValueError("Could not find clip image")
1415
+
1416
+ if not self.adapter_config.train_image_encoder:
1417
+ # we are not training the image encoder, so we need to detach the embeds
1418
+ conditional_clip_embeds = conditional_clip_embeds.detach()
1419
+ if self.train_config.do_cfg:
1420
+ unconditional_clip_embeds = unconditional_clip_embeds.detach()
1421
+
1422
+ with self.timer('encode_adapter'):
1423
+ self.adapter.train()
1424
+ conditional_embeds = self.adapter(
1425
+ conditional_embeds.detach(),
1426
+ conditional_clip_embeds,
1427
+ is_unconditional=False
1428
+ )
1429
+ if self.train_config.do_cfg:
1430
+ unconditional_embeds = self.adapter(
1431
+ unconditional_embeds.detach(),
1432
+ unconditional_clip_embeds,
1433
+ is_unconditional=True
1434
+ )
1435
+ else:
1436
+ # wipe out unconditional
1437
+ self.adapter.last_unconditional = None
1438
+
1439
+ if self.adapter and isinstance(self.adapter, ReferenceAdapter):
1440
+ # pass in our scheduler
1441
+ self.adapter.noise_scheduler = self.lr_scheduler
1442
+ if has_clip_image or has_adapter_img:
1443
+ img_to_use = clip_images if has_clip_image else adapter_images
1444
+ # currently 0-1 needs to be -1 to 1
1445
+ reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype)
1446
+ self.adapter.set_reference_images(reference_images)
1447
+ self.adapter.noise_scheduler = self.sd.noise_scheduler
1448
+ elif is_reg:
1449
+ self.adapter.set_blank_reference_images(noisy_latents.shape[0])
1450
+ else:
1451
+ self.adapter.set_reference_images(None)
1452
+
1453
+ prior_pred = None
1454
+
1455
+ do_reg_prior = False
1456
+ # if is_reg and (self.network is not None or self.adapter is not None):
1457
+ # # we are doing a reg image and we have a network or adapter
1458
+ # do_reg_prior = True
1459
+
1460
+ do_inverted_masked_prior = False
1461
+ if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
1462
+ do_inverted_masked_prior = True
1463
+
1464
+ do_correct_pred_norm_prior = self.train_config.correct_pred_norm
1465
+
1466
+ do_guidance_prior = False
1467
+
1468
+ if batch.unconditional_latents is not None:
1469
+ # for 'this not that' (tnt) guidance, we need a prior pred to normalize
1470
+ guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
1471
+ if guidance_type == 'tnt':
1472
+ do_guidance_prior = True
1473
+
1474
+ if ((
1475
+ has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
1476
+ with self.timer('prior predict'):
1477
+ prior_pred = self.get_prior_prediction(
1478
+ noisy_latents=noisy_latents,
1479
+ conditional_embeds=conditional_embeds,
1480
+ match_adapter_assist=match_adapter_assist,
1481
+ network_weight_list=network_weight_list,
1482
+ timesteps=timesteps,
1483
+ pred_kwargs=pred_kwargs,
1484
+ noise=noise,
1485
+ batch=batch,
1486
+ unconditional_embeds=unconditional_embeds,
1487
+ conditioned_prompts=conditioned_prompts
1488
+ )
1489
+ if prior_pred is not None:
1490
+ prior_pred = prior_pred.detach()
1491
+
1492
+ # do the custom adapter after the prior prediction
1493
+ if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
1494
+ quad_count = random.randint(1, 4)
1495
+ self.adapter.train()
1496
+ conditional_embeds = self.adapter.condition_encoded_embeds(
1497
+ tensors_0_1=clip_images,
1498
+ prompt_embeds=conditional_embeds,
1499
+ is_training=True,
1500
+ has_been_preprocessed=True,
1501
+ quad_count=quad_count
1502
+ )
1503
+ if self.train_config.do_cfg and unconditional_embeds is not None:
1504
+ unconditional_embeds = self.adapter.condition_encoded_embeds(
1505
+ tensors_0_1=clip_images,
1506
+ prompt_embeds=unconditional_embeds,
1507
+ is_training=True,
1508
+ has_been_preprocessed=True,
1509
+ is_unconditional=True,
1510
+ quad_count=quad_count
1511
+ )
1512
+
1513
+ if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None:
1514
+ self.adapter.add_extra_values(batch.extra_values.detach())
1515
+
1516
+ if self.train_config.do_cfg:
1517
+ self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()),
1518
+ is_unconditional=True)
1519
+
1520
+ if has_adapter_img:
1521
+ if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
1522
+ self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
1523
+ if self.train_config.do_cfg:
1524
+ raise ValueError("ControlNetModel is not supported with CFG")
1525
+ with torch.set_grad_enabled(self.adapter is not None):
1526
+ adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
1527
+ adapter_multiplier = get_adapter_multiplier()
1528
+ with self.timer('encode_adapter'):
1529
+ # add_text_embeds is pooled_prompt_embeds for sdxl
1530
+ added_cond_kwargs = {}
1531
+ if self.sd.is_xl:
1532
+ added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
1533
+ added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
1534
+ down_block_res_samples, mid_block_res_sample = adapter(
1535
+ noisy_latents,
1536
+ timesteps,
1537
+ encoder_hidden_states=conditional_embeds.text_embeds,
1538
+ controlnet_cond=adapter_images,
1539
+ conditioning_scale=1.0,
1540
+ guess_mode=False,
1541
+ added_cond_kwargs=added_cond_kwargs,
1542
+ return_dict=False,
1543
+ )
1544
+ pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
1545
+ pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
1546
+
1547
+ self.before_unet_predict()
1548
+ # do a prior pred if we have an unconditional image, we will swap out the guidance later
1549
+ if batch.unconditional_latents is not None or self.do_guided_loss:
1550
+ # do guided loss
1551
+ loss = self.get_guided_loss(
1552
+ noisy_latents=noisy_latents,
1553
+ conditional_embeds=conditional_embeds,
1554
+ match_adapter_assist=match_adapter_assist,
1555
+ network_weight_list=network_weight_list,
1556
+ timesteps=timesteps,
1557
+ pred_kwargs=pred_kwargs,
1558
+ batch=batch,
1559
+ noise=noise,
1560
+ unconditional_embeds=unconditional_embeds,
1561
+ mask_multiplier=mask_multiplier,
1562
+ prior_pred=prior_pred,
1563
+ )
1564
+
1565
+ else:
1566
+ with self.timer('predict_unet'):
1567
+ if unconditional_embeds is not None:
1568
+ unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
1569
+ noise_pred = self.predict_noise(
1570
+ noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
1571
+ timesteps=timesteps,
1572
+ conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
1573
+ unconditional_embeds=unconditional_embeds,
1574
+ **pred_kwargs
1575
+ )
1576
+ self.after_unet_predict()
1577
+
1578
+ with self.timer('calculate_loss'):
1579
+ noise = noise.to(self.device_torch, dtype=dtype).detach()
1580
+ loss = self.calculate_loss(
1581
+ noise_pred=noise_pred,
1582
+ noise=noise,
1583
+ noisy_latents=noisy_latents,
1584
+ timesteps=timesteps,
1585
+ batch=batch,
1586
+ mask_multiplier=mask_multiplier,
1587
+ prior_pred=prior_pred,
1588
+ )
1589
+ # check if nan
1590
+ if torch.isnan(loss):
1591
+ print("loss is nan")
1592
+ loss = torch.zeros_like(loss).requires_grad_(True)
1593
+
1594
+ with self.timer('backward'):
1595
+ # todo we have multiplier separated. works for now as res are not in same batch, but need to change
1596
+ loss = loss * loss_multiplier.mean()
1597
+ # IMPORTANT: if gradient checkpointing is enabled, do not leave the `with network` context when doing backward
1598
+ # it will destroy the gradients. This is because the network is a context manager
1599
+ # and will change the multipliers back to 0.0 when exiting. They will be
1600
+ # 0.0 for the backward pass and the gradients will be 0.0
1601
+ # I spent weeks on fighting this. DON'T DO IT
1602
+ # with fsdp_overlap_step_with_backward():
1603
+ # if self.is_bfloat:
1604
+ # loss.backward()
1605
+ # else:
1606
+ if not self.do_grad_scale:
1607
+ loss.backward()
1608
+ else:
1609
+ self.scaler.scale(loss).backward()
1610
+
1611
+ return loss.detach()
1612
+ # flush()
1613
+
1614
+ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
1615
+ if isinstance(batch, list):
1616
+ batch_list = batch
1617
+ else:
1618
+ batch_list = [batch]
1619
+ total_loss = None
1620
+ self.optimizer.zero_grad()
1621
+ for batch in batch_list:
1622
+ loss = self.train_single_accumulation(batch)
1623
+ if total_loss is None:
1624
+ total_loss = loss
1625
+ else:
1626
+ total_loss += loss
1627
+ if len(batch_list) > 1 and self.model_config.low_vram:
1628
+ torch.cuda.empty_cache()
1629
+
1630
+
1631
+ if not self.is_grad_accumulation_step:
1632
+ # fix this for multi params
1633
+ if self.train_config.optimizer != 'adafactor':
1634
+ if self.do_grad_scale:
1635
+ self.scaler.unscale_(self.optimizer)
1636
+ if isinstance(self.params[0], dict):
1637
+ for i in range(len(self.params)):
1638
+ torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
1639
+ else:
1640
+ torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
1641
+ # only step if we are not accumulating
1642
+ with self.timer('optimizer_step'):
1643
+ # self.optimizer.step()
1644
+ if not self.do_grad_scale:
1645
+ self.optimizer.step()
1646
+ else:
1647
+ self.scaler.step(self.optimizer)
1648
+ self.scaler.update()
1649
+
1650
+ self.optimizer.zero_grad(set_to_none=True)
1651
+ if self.adapter and isinstance(self.adapter, CustomAdapter):
1652
+ self.adapter.post_weight_update()
1653
+ if self.ema is not None:
1654
+ with self.timer('ema_update'):
1655
+ self.ema.update()
1656
+ else:
1657
+ # gradient accumulation. Just a place for breakpoint
1658
+ pass
1659
+
1660
+ # TODO Should we only step scheduler on grad step? If so, need to recalculate last step
1661
+ with self.timer('scheduler_step'):
1662
+ self.lr_scheduler.step()
1663
+
1664
+ if self.embedding is not None:
1665
+ with self.timer('restore_embeddings'):
1666
+ # Let's make sure we don't update any embedding weights besides the newly added token
1667
+ self.embedding.restore_embeddings()
1668
+ if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
1669
+ with self.timer('restore_adapter'):
1670
+ # Let's make sure we don't update any embedding weights besides the newly added token
1671
+ self.adapter.restore_embeddings()
1672
+
1673
+ loss_dict = OrderedDict(
1674
+ {'loss': loss.item()}
1675
+ )
1676
+
1677
+ self.end_of_training_loop()
1678
+
1679
+ return loss_dict
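A minimal standalone sketch of the masked-loss weighting used in train_single_accumulation above, assuming only torch; the shapes and values are illustrative, not taken from the repo:

import torch
import torch.nn.functional as F

bs, ch, h, w = 2, 4, 64, 64
noise_pred = torch.randn(bs, ch, h, w)
noise = torch.randn(bs, ch, h, w)

# dataset mask at pixel resolution (bs, 1, H, W)
mask = (torch.rand(bs, 1, 512, 512) > 0.5).float()

# downsample to latent resolution and broadcast across channels,
# mirroring the interpolate/expand calls on mask_multiplier above
mask_multiplier = F.interpolate(mask, size=(h, w))
mask_multiplier = mask_multiplier.expand(-1, ch, -1, -1)

# per-pixel MSE weighted by the mask, then reduced to a scalar
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none")
loss = (loss * mask_multiplier).mean([1, 2, 3]).mean()
print(loss.item())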
extensions_built_in/sd_trainer/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
6
+ class SDTrainerExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "sd_trainer"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "SD Trainer"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .SDTrainer import SDTrainer
19
+ return SDTrainer
20
+
21
+
22
+ # for backwards compatibility
23
+ class TextualInversionTrainer(SDTrainerExtension):
24
+ uid = "textual_inversion_trainer"
25
+
26
+
27
+ AI_TOOLKIT_EXTENSIONS = [
28
+ # you can put a list of extensions here
29
+ SDTrainerExtension, TextualInversionTrainer
30
+ ]
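For reference, a third-party extension would follow the same shape; the class, uid, and module names below are hypothetical placeholders, not part of this commit. The uid appears to be what a job config's process "type" field refers to (as in the example YAML that follows).

from toolkit.extension import Extension


class MyTrainerExtension(Extension):
    # uid must be unique; a job config would reference it via type: "my_trainer"
    uid = "my_trainer"
    name = "My Trainer"

    @classmethod
    def get_process(cls):
        # deferred import, as in SDTrainerExtension, so startup stays fast
        from .MyTrainer import MyTrainer  # hypothetical process module
        return MyTrainer


AI_TOOLKIT_EXTENSIONS = [MyTrainerExtension]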
extensions_built_in/sd_trainer/config/train.example.yaml ADDED
@@ -0,0 +1,91 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: test_v1
5
+ process:
6
+ - type: 'textual_inversion_trainer'
7
+ training_folder: "out/TI"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "out/.tensorboard"
11
+ embedding:
12
+ trigger: "your_trigger_here"
13
+ tokens: 12
14
+ init_words: "man with short brown hair"
15
+ save_format: "safetensors" # 'safetensors' or 'pt'
16
+ save:
17
+ dtype: float16 # precision to save
18
+ save_every: 100 # save every this many steps
19
+ max_step_saves_to_keep: 5 # only affects step counts
20
+ datasets:
21
+ - folder_path: "/path/to/dataset"
22
+ caption_ext: "txt"
23
+ default_caption: "[trigger]"
24
+ buckets: true
25
+ resolution: 512
26
+ train:
27
+ noise_scheduler: "ddpm" # options: "ddpm", "lms", "euler_a"
28
+ steps: 3000
29
+ weight_jitter: 0.0
30
+ lr: 5e-5
31
+ train_unet: false
32
+ gradient_checkpointing: true
33
+ train_text_encoder: false
34
+ optimizer: "adamw"
35
+ # optimizer: "prodigy"
36
+ optimizer_params:
37
+ weight_decay: 1e-2
38
+ lr_scheduler: "constant"
39
+ max_denoising_steps: 1000
40
+ batch_size: 4
41
+ dtype: bf16
42
+ xformers: true
43
+ min_snr_gamma: 5.0
44
+ # skip_first_sample: true
45
+ noise_offset: 0.0 # not needed for this
46
+ model:
47
+ # objective reality v2
48
+ name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
49
+ is_v2: false # for v2 models
50
+ is_xl: false # for SDXL models
51
+ is_v_pred: false # for v-prediction models (most v2 models)
52
+ sample:
53
+ sampler: "ddpm" # must match train.noise_scheduler
54
+ sample_every: 100 # sample every this many steps
55
+ width: 512
56
+ height: 512
57
+ prompts:
58
+ - "photo of [trigger] laughing"
59
+ - "photo of [trigger] smiling"
60
+ - "[trigger] close up"
61
+ - "dark scene [trigger] frozen"
62
+ - "[trigger] nighttime"
63
+ - "a painting of [trigger]"
64
+ - "a drawing of [trigger]"
65
+ - "a cartoon of [trigger]"
66
+ - "[trigger] pixar style"
67
+ - "[trigger] costume"
68
+ neg: ""
69
+ seed: 42
70
+ walk_seed: false
71
+ guidance_scale: 7
72
+ sample_steps: 20
73
+ network_multiplier: 1.0
74
+
75
+ logging:
76
+ log_every: 10 # log every this many steps
77
+ use_wandb: false # not supported yet
78
+ verbose: false
79
+
80
+ # You can put any information you want here, and it will be saved in the model.
81
+ # The below is an example, but you can put your grocery list in it if you want.
82
+ # It is saved in the model so be aware of that. The software will include this
83
+ # plus some other information for you automatically
84
+ meta:
85
+ # [name] gets replaced with the name above
86
+ name: "[name]"
87
+ # version: '1.0'
88
+ # creator:
89
+ # name: Your Name
90
+ # email: [email protected]
91
+ # website: https://your.website
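A quick sketch of how the [trigger] placeholder in this example config expands into concrete sample prompts, assuming PyYAML and a plain string replacement (the same mechanism the meta section describes for [name]):

import yaml

with open("extensions_built_in/sd_trainer/config/train.example.yaml") as f:
    cfg = yaml.safe_load(f)

process = cfg["config"]["process"][0]
trigger = process["embedding"]["trigger"]  # "your_trigger_here"

for prompt in process["sample"]["prompts"]:
    print(prompt.replace("[trigger]", trigger))
# e.g. "photo of your_trigger_here laughing"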
extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py ADDED
@@ -0,0 +1,533 @@
1
+ import copy
2
+ import random
3
+ from collections import OrderedDict
4
+ import os
5
+ from contextlib import nullcontext
6
+ from typing import Optional, Union, List
7
+ from torch.utils.data import ConcatDataset, DataLoader
8
+
9
+ from toolkit.config_modules import ReferenceDatasetConfig
10
+ from toolkit.data_loader import PairedImageDataset
11
+ from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds, build_latent_image_batch_for_prompt_pair
12
+ from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
13
+ from toolkit.train_tools import get_torch_dtype, apply_snr_weight
14
+ import gc
15
+ from toolkit import train_tools
16
+ import torch
17
+ from jobs.process import BaseSDTrainProcess
18
+ import random
19
+
20
+ import random
21
+ from collections import OrderedDict
22
+ from tqdm import tqdm
23
+
24
+ from toolkit.config_modules import SliderConfig
25
+ from toolkit.train_tools import get_torch_dtype, apply_snr_weight
26
+ import gc
27
+ from toolkit import train_tools
28
+ from toolkit.prompt_utils import \
29
+ EncodedPromptPair, ACTION_TYPES_SLIDER, \
30
+ EncodedAnchor, concat_prompt_pairs, \
31
+ concat_anchors, PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \
32
+ split_prompt_pairs
33
+
34
+ import torch
35
+
36
+
37
+ def flush():
38
+ torch.cuda.empty_cache()
39
+ gc.collect()
40
+
41
+
42
+ class UltimateSliderConfig(SliderConfig):
43
+ def __init__(self, **kwargs):
44
+ super().__init__(**kwargs)
45
+ self.additional_losses: List[str] = kwargs.get('additional_losses', [])
46
+ self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
47
+ self.img_loss_weight: float = kwargs.get('img_loss_weight', 1.0)
48
+ self.cfg_loss_weight: float = kwargs.get('cfg_loss_weight', 1.0)
49
+ self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])]
50
+
51
+
52
+ class UltimateSliderTrainerProcess(BaseSDTrainProcess):
53
+ sd: StableDiffusion
54
+ data_loader: DataLoader = None
55
+
56
+ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
57
+ super().__init__(process_id, job, config, **kwargs)
58
+ self.prompt_txt_list = None
59
+ self.step_num = 0
60
+ self.start_step = 0
61
+ self.device = self.get_conf('device', self.job.device)
62
+ self.device_torch = torch.device(self.device)
63
+ self.slider_config = UltimateSliderConfig(**self.get_conf('slider', {}))
64
+
65
+ self.prompt_cache = PromptEmbedsCache()
66
+ self.prompt_pairs: list[EncodedPromptPair] = []
67
+ self.anchor_pairs: list[EncodedAnchor] = []
68
+ # keep track of prompt chunk size
69
+ self.prompt_chunk_size = 1
70
+
71
+ # store a list of all the prompts from the dataset so we can cache it
72
+ self.dataset_prompts = []
73
+ self.train_with_dataset = self.slider_config.datasets is not None and len(self.slider_config.datasets) > 0
74
+
75
+ def load_datasets(self):
76
+ if self.data_loader is None and \
77
+ self.slider_config.datasets is not None and len(self.slider_config.datasets) > 0:
78
+ print(f"Loading datasets")
79
+ datasets = []
80
+ for dataset in self.slider_config.datasets:
81
+ print(f" - Dataset: {dataset.pair_folder}")
82
+ config = {
83
+ 'path': dataset.pair_folder,
84
+ 'size': dataset.size,
85
+ 'default_prompt': dataset.target_class,
86
+ 'network_weight': dataset.network_weight,
87
+ 'pos_weight': dataset.pos_weight,
88
+ 'neg_weight': dataset.neg_weight,
89
+ 'pos_folder': dataset.pos_folder,
90
+ 'neg_folder': dataset.neg_folder,
91
+ }
92
+ image_dataset = PairedImageDataset(config)
93
+ datasets.append(image_dataset)
94
+
95
+ # capture all the prompts from it so we can cache the embeds
96
+ self.dataset_prompts += image_dataset.get_all_prompts()
97
+
98
+ concatenated_dataset = ConcatDataset(datasets)
99
+ self.data_loader = DataLoader(
100
+ concatenated_dataset,
101
+ batch_size=self.train_config.batch_size,
102
+ shuffle=True,
103
+ num_workers=2
104
+ )
105
+
106
+ def before_model_load(self):
107
+ pass
108
+
109
+ def hook_before_train_loop(self):
110
+ # load any datasets if they were passed
111
+ self.load_datasets()
112
+
113
+ # read line by line from file
114
+ if self.slider_config.prompt_file:
115
+ self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
116
+ with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
117
+ self.prompt_txt_list = f.readlines()
118
+ # clean empty lines
119
+ self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
120
+
121
+ self.print(f"Found {len(self.prompt_txt_list)} prompts.")
122
+
123
+ if not self.slider_config.prompt_tensors:
124
+ print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.")
125
+ # shuffle
126
+ random.shuffle(self.prompt_txt_list)
127
+ # trim to max steps
128
+ self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
129
+ # trim list to our max steps
130
+
131
+ cache = PromptEmbedsCache()
132
+
133
+ # get encoded latents for our prompts
134
+ with torch.no_grad():
135
+ # list of neutrals. Can come from file or be empty
136
+ neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
137
+
138
+ # build the prompts to cache
139
+ prompts_to_cache = []
140
+ for neutral in neutral_list:
141
+ for target in self.slider_config.targets:
142
+ prompt_list = [
143
+ f"{target.target_class}", # target_class
144
+ f"{target.target_class} {neutral}", # target_class with neutral
145
+ f"{target.positive}", # positive_target
146
+ f"{target.positive} {neutral}", # positive_target with neutral
147
+ f"{target.negative}", # negative_target
148
+ f"{target.negative} {neutral}", # negative_target with neutral
149
+ f"{neutral}", # neutral
150
+ f"{target.positive} {target.negative}", # both targets
151
+ f"{target.negative} {target.positive}", # both targets reverse
152
+ ]
153
+ prompts_to_cache += prompt_list
154
+
155
+ # remove duplicates
156
+ prompts_to_cache = list(dict.fromkeys(prompts_to_cache))
157
+
158
+ # trim to max steps if max steps is lower than prompt count
159
+ prompts_to_cache = prompts_to_cache[:self.train_config.steps]
160
+
161
+ if len(self.dataset_prompts) > 0:
162
+ # add the prompts from the dataset
163
+ prompts_to_cache += self.dataset_prompts
164
+
165
+ # encode them
166
+ cache = encode_prompts_to_cache(
167
+ prompt_list=prompts_to_cache,
168
+ sd=self.sd,
169
+ cache=cache,
170
+ prompt_tensor_file=self.slider_config.prompt_tensors
171
+ )
172
+
173
+ prompt_pairs = []
174
+ prompt_batches = []
175
+ for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False):
176
+ for target in self.slider_config.targets:
177
+ prompt_pair_batch = build_prompt_pair_batch_from_cache(
178
+ cache=cache,
179
+ target=target,
180
+ neutral=neutral,
181
+
182
+ )
183
+ if self.slider_config.batch_full_slide:
184
+ # concat the prompt pairs
185
+ # this allows us to run the entire 4 part process in one shot (for slider)
186
+ self.prompt_chunk_size = 4
187
+ concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu')
188
+ prompt_pairs += [concat_prompt_pair_batch]
189
+ else:
190
+ self.prompt_chunk_size = 1
191
+ # do them one at a time (probably not necessary after new optimizations)
192
+ prompt_pairs += [x.to('cpu') for x in prompt_pair_batch]
193
+
194
+ # move to cpu to save vram
195
+ # We don't need text encoder anymore, but keep it on cpu for sampling
196
+ # if text encoder is list
197
+ if isinstance(self.sd.text_encoder, list):
198
+ for encoder in self.sd.text_encoder:
199
+ encoder.to("cpu")
200
+ else:
201
+ self.sd.text_encoder.to("cpu")
202
+ self.prompt_cache = cache
203
+ self.prompt_pairs = prompt_pairs
204
+ # end hook_before_train_loop
205
+
206
+ # move vae to device so we can encode on the fly
207
+ # todo cache latents
208
+ self.sd.vae.to(self.device_torch)
209
+ self.sd.vae.eval()
210
+ self.sd.vae.requires_grad_(False)
211
+
212
+ if self.train_config.gradient_checkpointing:
213
+ # may get disabled elsewhere
214
+ self.sd.unet.enable_gradient_checkpointing()
215
+
216
+ flush()
217
+ # end hook_before_train_loop
218
+
219
+ def hook_train_loop(self, batch):
220
+ dtype = get_torch_dtype(self.train_config.dtype)
221
+
222
+ with torch.no_grad():
223
+ ### LOOP SETUP ###
224
+ noise_scheduler = self.sd.noise_scheduler
225
+ optimizer = self.optimizer
226
+ lr_scheduler = self.lr_scheduler
227
+
228
+ ### TARGET_PROMPTS ###
229
+ # get a random pair
230
+ prompt_pair: EncodedPromptPair = self.prompt_pairs[
231
+ torch.randint(0, len(self.prompt_pairs), (1,)).item()
232
+ ]
233
+ # move to device and dtype
234
+ prompt_pair.to(self.device_torch, dtype=dtype)
235
+
236
+ ### PREP REFERENCE IMAGES ###
237
+
238
+ imgs, prompts, network_weights = batch
239
+ network_pos_weight, network_neg_weight = network_weights
240
+
241
+ if isinstance(network_pos_weight, torch.Tensor):
242
+ network_pos_weight = network_pos_weight.item()
243
+ if isinstance(network_neg_weight, torch.Tensor):
244
+ network_neg_weight = network_neg_weight.item()
245
+
246
+ # get a random jitter value between -weight_jitter and weight_jitter
247
+ weight_jitter = self.slider_config.weight_jitter
248
+ if weight_jitter > 0.0:
249
+ jitter_list = random.uniform(-weight_jitter, weight_jitter)
250
+ network_pos_weight += jitter_list
251
+ network_neg_weight += (jitter_list * -1.0)
252
+
253
+ # if items in network_weight list are tensors, convert them to floats
254
+ imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
255
+ # split batched images in half so left is negative and right is positive
256
+ negative_images, positive_images = torch.chunk(imgs, 2, dim=3)
257
+
258
+ height = positive_images.shape[2]
259
+ width = positive_images.shape[3]
260
+ batch_size = positive_images.shape[0]
261
+
262
+ positive_latents = self.sd.encode_images(positive_images)
263
+ negative_latents = self.sd.encode_images(negative_images)
264
+
265
+ self.sd.noise_scheduler.set_timesteps(
266
+ self.train_config.max_denoising_steps, device=self.device_torch
267
+ )
268
+
269
+ timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
270
+ current_timestep_index = timesteps.item()
271
+ current_timestep = noise_scheduler.timesteps[current_timestep_index]
272
+ timesteps = timesteps.long()
273
+
274
+ # get noise
275
+ noise_positive = self.sd.get_latent_noise(
276
+ pixel_height=height,
277
+ pixel_width=width,
278
+ batch_size=batch_size,
279
+ noise_offset=self.train_config.noise_offset,
280
+ ).to(self.device_torch, dtype=dtype)
281
+
282
+ noise_negative = noise_positive.clone()
283
+
284
+ # Add noise to the latents according to the noise magnitude at each timestep
285
+ # (this is the forward diffusion process)
286
+ noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps)
287
+ noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps)
288
+
289
+ ### CFG SLIDER TRAINING PREP ###
290
+
291
+ # get CFG txt latents
292
+ noisy_cfg_latents = build_latent_image_batch_for_prompt_pair(
293
+ pos_latent=noisy_positive_latents,
294
+ neg_latent=noisy_negative_latents,
295
+ prompt_pair=prompt_pair,
296
+ prompt_chunk_size=self.prompt_chunk_size,
297
+ )
298
+ noisy_cfg_latents.requires_grad = False
299
+
300
+ assert not self.network.is_active
301
+
302
+ # 4.20 GB RAM for 512x512
303
+ positive_latents = self.sd.predict_noise(
304
+ latents=noisy_cfg_latents,
305
+ text_embeddings=train_tools.concat_prompt_embeddings(
306
+ prompt_pair.positive_target, # negative prompt
307
+ prompt_pair.negative_target, # positive prompt
308
+ self.train_config.batch_size,
309
+ ),
310
+ timestep=current_timestep,
311
+ guidance_scale=1.0
312
+ )
313
+ positive_latents.requires_grad = False
314
+
315
+ neutral_latents = self.sd.predict_noise(
316
+ latents=noisy_cfg_latents,
317
+ text_embeddings=train_tools.concat_prompt_embeddings(
318
+ prompt_pair.positive_target, # negative prompt
319
+ prompt_pair.empty_prompt, # positive prompt (normally neutral)
320
+ self.train_config.batch_size,
321
+ ),
322
+ timestep=current_timestep,
323
+ guidance_scale=1.0
324
+ )
325
+ neutral_latents.requires_grad = False
326
+
327
+ unconditional_latents = self.sd.predict_noise(
328
+ latents=noisy_cfg_latents,
329
+ text_embeddings=train_tools.concat_prompt_embeddings(
330
+ prompt_pair.positive_target, # negative prompt
331
+ prompt_pair.positive_target, # positive prompt
332
+ self.train_config.batch_size,
333
+ ),
334
+ timestep=current_timestep,
335
+ guidance_scale=1.0
336
+ )
337
+ unconditional_latents.requires_grad = False
338
+
339
+ positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
340
+ neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
341
+ unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
342
+ prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
343
+ noisy_cfg_latents_chunks = torch.chunk(noisy_cfg_latents, self.prompt_chunk_size, dim=0)
344
+ assert len(prompt_pair_chunks) == len(noisy_cfg_latents_chunks)
345
+
346
+ noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
347
+ noise = torch.cat([noise_positive, noise_negative], dim=0)
348
+ timesteps = torch.cat([timesteps, timesteps], dim=0)
349
+ network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]
350
+
351
+ flush()
352
+
353
+ loss_float = None
354
+ loss_mirror_float = None
355
+
356
+ self.optimizer.zero_grad()
357
+ noisy_latents.requires_grad = False
358
+
359
+ # TODO allow both processes to train the text encoder; for now, we just train the unet and cache all text encodings
360
+ # if training text encoder enable grads, else do context of no grad
361
+ # with torch.set_grad_enabled(self.train_config.train_text_encoder):
362
+ # # text encoding
363
+ # embedding_list = []
364
+ # # embed the prompts
365
+ # for prompt in prompts:
366
+ # embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
367
+ # embedding_list.append(embedding)
368
+ # conditional_embeds = concat_prompt_embeds(embedding_list)
369
+ # conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
370
+
371
+ if self.train_with_dataset:
372
+ embedding_list = []
373
+ with torch.set_grad_enabled(self.train_config.train_text_encoder):
374
+ for prompt in prompts:
375
+ # get embedding from cache
376
+ embedding = self.prompt_cache[prompt]
377
+ embedding = embedding.to(self.device_torch, dtype=dtype)
378
+ embedding_list.append(embedding)
379
+ conditional_embeds = concat_prompt_embeds(embedding_list)
380
+ # double up so we can do both sides of the slider
381
+ conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
382
+ else:
383
+ # throw error. Not supported yet
384
+ raise Exception("Datasets and targets required for ultimate slider")
385
+
386
+ if self.model_config.is_xl:
387
+ # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
388
+ network_multiplier_list = network_multiplier
389
+ noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
390
+ noise_list = torch.chunk(noise, 2, dim=0)
391
+ timesteps_list = torch.chunk(timesteps, 2, dim=0)
392
+ conditional_embeds_list = split_prompt_embeds(conditional_embeds)
393
+ else:
394
+ network_multiplier_list = [network_multiplier]
395
+ noisy_latent_list = [noisy_latents]
396
+ noise_list = [noise]
397
+ timesteps_list = [timesteps]
398
+ conditional_embeds_list = [conditional_embeds]
399
+
400
+ ## DO REFERENCE IMAGE TRAINING ##
401
+
402
+ reference_image_losses = []
403
+ # allow to chunk it out to save vram
404
+ for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip(
405
+ network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list
406
+ ):
407
+ with self.network:
408
+ assert self.network.is_active
409
+
410
+ self.network.multiplier = network_multiplier
411
+
412
+ noise_pred = self.sd.predict_noise(
413
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
414
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
415
+ timestep=timesteps,
416
+ )
417
+ noise = noise.to(self.device_torch, dtype=dtype)
418
+
419
+ if self.sd.prediction_type == 'v_prediction':
420
+ # v-parameterization training
421
+ target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
422
+ else:
423
+ target = noise
424
+
425
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
426
+ loss = loss.mean([1, 2, 3])
427
+
428
+ # todo add snr gamma here
429
+ if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
430
+ # add min_snr_gamma
431
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
432
+
433
+ loss = loss.mean()
434
+ loss = loss * self.slider_config.img_loss_weight
435
+ loss_slide_float = loss.item()
436
+
437
+ loss_float = loss.item()
438
+ reference_image_losses.append(loss_float)
439
+
440
+ # back propagate loss to free ram
441
+ loss.backward()
442
+ flush()
443
+
444
+ ## DO CFG SLIDER TRAINING ##
445
+
+         cfg_loss_list = []
+
+         with self.network:
+             assert self.network.is_active
+             for prompt_pair_chunk, \
+                     noisy_cfg_latent_chunk, \
+                     positive_latents_chunk, \
+                     neutral_latents_chunk, \
+                     unconditional_latents_chunk \
+                     in zip(
+                         prompt_pair_chunks,
+                         noisy_cfg_latents_chunks,
+                         positive_latents_chunks,
+                         neutral_latents_chunks,
+                         unconditional_latents_chunks,
+                     ):
+                 self.network.multiplier = prompt_pair_chunk.multiplier_list
+
+                 target_latents = self.sd.predict_noise(
+                     latents=noisy_cfg_latent_chunk,
+                     text_embeddings=train_tools.concat_prompt_embeddings(
+                         prompt_pair_chunk.positive_target,  # negative prompt
+                         prompt_pair_chunk.target_class,  # positive prompt
+                         self.train_config.batch_size,
+                     ),
+                     timestep=current_timestep,
+                     guidance_scale=1.0
+                 )
+
+                 guidance_scale = 1.0
+
+                 offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk)
+
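+                 # NOTE: the regression target below is roughly
+                 #     neutral + s * guidance_scale * (positive - unconditional)
+                 # where s is +1 or -1 depending on whether the action enhances or erases the concept.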
+                 # make an offset multiplier based on the configured actions
+                 offset_multiplier_list = []
+                 for action in prompt_pair_chunk.action_list:
+                     if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE:
+                         offset_multiplier_list += [-1.0]
+                     elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE:
+                         offset_multiplier_list += [1.0]
+
+                 offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype)
+                 # reshape the offset multiplier to match the rank of the offset tensor
+                 offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1)
+                 offset *= offset_multiplier
+
+                 offset_neutral = neutral_latents_chunk
+                 # offsets are already adjusted on a per-batch basis
+                 offset_neutral += offset
+
+                 # 16.15 GB VRAM for 512x512 -> 4.20 GB VRAM for 512x512 with the new grad checkpointing
+                 loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
+                 loss = loss.mean([1, 2, 3])
+
+                 if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+                     # match the batch size
+                     timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
+                     # apply min_snr_gamma weighting
+                     loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler,
+                                             self.train_config.min_snr_gamma)
+
+                 loss = loss.mean() * prompt_pair_chunk.weight * self.slider_config.cfg_loss_weight
+
+                 loss.backward()
+                 cfg_loss_list.append(loss.item())
+                 del target_latents
+                 del offset_neutral
+                 del loss
+                 flush()
+
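+         # NOTE: gradients from the reference-image pass and the CFG slider pass have been
+         # accumulated on the LoRA weights by the per-chunk backward() calls above; a single
+         # optimizer step applies them all at once.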
+         # apply gradients
+         optimizer.step()
+         lr_scheduler.step()
+
+         # reset network
+         self.network.multiplier = 1.0
+
+         reference_image_loss = sum(reference_image_losses) / len(reference_image_losses) if len(
+             reference_image_losses) > 0 else 0.0
+         cfg_loss = sum(cfg_loss_list) / len(cfg_loss_list) if len(cfg_loss_list) > 0 else 0.0
+
+         loss_dict = OrderedDict({
+             'loss/img': reference_image_loss,
+             'loss/cfg': cfg_loss,
+         })
+
+         return loss_dict
+     # end hook_train_loop
extensions_built_in/ultimate_slider_trainer/__init__.py ADDED
@@ -0,0 +1,25 @@
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
+ from toolkit.extension import Extension
+
+
+ # We make a subclass of Extension
+ class UltimateSliderTrainer(Extension):
+     # uid must be unique, it is how the extension is identified
+     uid = "ultimate_slider_trainer"
+
+     # name is the name of the extension for printing
+     name = "Ultimate Slider Trainer"
+
+     # This is where your process class is loaded.
+     # Keep your imports in here so they don't slow down the rest of the program.
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed, and return it
+         from .UltimateSliderTrainerProcess import UltimateSliderTrainerProcess
+         return UltimateSliderTrainerProcess
+
+
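+ # NOTE: the toolkit appears to discover extensions by importing this module and reading
+ # AI_TOOLKIT_EXTENSIONS; a job config selects this trainer with `type: 'ultimate_slider_trainer'`,
+ # matching the uid above.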
+ AI_TOOLKIT_EXTENSIONS = [
+     # you can put a list of extensions here
+     UltimateSliderTrainer
+ ]
extensions_built_in/ultimate_slider_trainer/config/train.example.yaml ADDED
@@ -0,0 +1,107 @@
+ ---
+ job: extension
+ config:
+   name: example_name
+   process:
+     - type: 'ultimate_slider_trainer' # must match the extension uid
+       training_folder: "/mnt/Train/out/LoRA"
+       device: cuda:0
+       # for tensorboard logging
+       log_dir: "/home/jaret/Dev/.tensorboard"
+       network:
+         type: "lora"
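+         # NOTE: `linear` is the LoRA rank and `linear_alpha` its alpha; the effective scale is
+         # commonly alpha / rank, so 8 / 8 keeps it at 1.0.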
+         linear: 8
+         linear_alpha: 8
+       train:
+         noise_scheduler: "ddpm" # or "lms", "euler_a"
+         steps: 5000
+         lr: 1e-4
+         train_unet: true
+         gradient_checkpointing: true
+         train_text_encoder: true
+         optimizer: "adamw"
+         optimizer_params:
+           weight_decay: 1e-2
+         lr_scheduler: "constant"
+         max_denoising_steps: 1000
+         batch_size: 1
+         dtype: bf16
+         xformers: true
+         skip_first_sample: true
+         noise_offset: 0.0
+       model:
+         name_or_path: "/path/to/model.safetensors"
+         is_v2: false # for v2 models
+         is_xl: false # for SDXL models
+         is_v_pred: false # for v-prediction models (most v2 models)
+       save:
+         dtype: float16 # precision to save
+         save_every: 1000 # save every this many steps
+         max_step_saves_to_keep: 2 # only keeps this many intermittent step saves
+       sample:
+         sampler: "ddpm" # must match train.noise_scheduler
+         sample_every: 100 # sample every this many steps
+         width: 512
+         height: 512
+         prompts:
+           - "photo of a woman with red hair taking a selfie --m -3"
+           - "photo of a woman with red hair taking a selfie --m -1"
+           - "photo of a woman with red hair taking a selfie --m 1"
+           - "photo of a woman with red hair taking a selfie --m 3"
+           - "close up photo of a man smiling at the camera, in a tank top --m -3"
+           - "close up photo of a man smiling at the camera, in a tank top --m -1"
+           - "close up photo of a man smiling at the camera, in a tank top --m 1"
+           - "close up photo of a man smiling at the camera, in a tank top --m 3"
+           - "photo of a blonde woman smiling, barista --m -3"
+           - "photo of a blonde woman smiling, barista --m -1"
+           - "photo of a blonde woman smiling, barista --m 1"
+           - "photo of a blonde woman smiling, barista --m 3"
+           - "photo of a Christina Hendricks --m -3"
+           - "photo of a Christina Hendricks --m -1"
+           - "photo of a Christina Hendricks --m 1"
+           - "photo of a Christina Hendricks --m 3"
+           - "photo of a Christina Ricci --m -3"
+           - "photo of a Christina Ricci --m -1"
+           - "photo of a Christina Ricci --m 1"
+           - "photo of a Christina Ricci --m 3"
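+         # NOTE: the `--m N` suffix sets the LoRA network multiplier for that sample prompt, so the
+         # prompts above preview the slider at strengths -3, -1, 1 and 3.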
+         neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
+         seed: 42
+         walk_seed: false
+         guidance_scale: 7
+         sample_steps: 20
+         network_multiplier: 1.0
+
+       logging:
+         log_every: 10 # log every this many steps
+         use_wandb: false # not supported yet
+         verbose: false
+
+       slider:
+         datasets:
+           - pair_folder: "/path/to/folder/side/by/side/images"
+             network_weight: 2.0
+             target_class: "" # only used as default if caption .txt files are not present
+             size: 512
+           - pair_folder: "/path/to/folder/side/by/side/images"
+             network_weight: 4.0
+             target_class: "" # only used as default if caption .txt files are not present
+             size: 512
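+         # NOTE (assumption): each `pair_folder` holds side-by-side image pairs, one half for each
+         # direction of the slider; `network_weight` is the multiplier those pairs are trained at,
+         # and a per-image caption .txt file overrides `target_class`.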
+
+
+ # you can put any information you want here, and it will be saved in the model
+ # the below is an example. I recommend doing trigger words at a minimum
+ # in the metadata. The software will include this plus some other information
+ meta:
+   name: "[name]" # [name] gets replaced with the name above
+   description: A short description of your model
+   trigger_words:
+     - put
+     - trigger
+     - words
+     - here
+   version: '0.1'
+   creator:
+     name: Your Name
+     website: https://yourwebsite.com
+   any: All metadata above is arbitrary; it can be whatever you want.