diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..cf45837f21585bed9e9040c490557e7f95f67e46 --- /dev/null +++ b/.gitignore @@ -0,0 +1,181 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ + +/env.sh +/models +/custom/* +!/custom/.gitkeep +/.tmp +/venv.bkp +/venv.* +/config/* +!/config/examples +!/config/_PUT_YOUR_CONFIGS_HERE).txt +/output/* +!/output/.gitkeep +/extensions/* +!/extensions/example +/temp +/wandb +.vscode/settings.json +.DS_Store +._.DS_Store +merge_file.py \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..657cf28b319df0b258ec737818f622e86eb44f16 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,12 @@ +[submodule "repositories/sd-scripts"] + path = repositories/sd-scripts + url = https://github.com/kohya-ss/sd-scripts.git +[submodule "repositories/leco"] + path = repositories/leco + url = https://github.com/p1atdev/LECO +[submodule "repositories/batch_annotator"] + path = repositories/batch_annotator + url = https://github.com/ostris/batch-annotator +[submodule "repositories/ipadapter"] + path = repositories/ipadapter + url = https://github.com/tencent-ailab/IP-Adapter.git diff --git a/FAQ.md b/FAQ.md new file mode 100644 index 0000000000000000000000000000000000000000..a13ba46585573c3dc7c4cde94422633f251f04bf --- /dev/null +++ b/FAQ.md @@ -0,0 +1,10 @@ +# FAQ + +WIP. Will continue to add things as they are needed. + +## FLUX.1 Training + +#### How much VRAM is required to train a lora on FLUX.1? + +24GB minimum is required. + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d72f95d548901698a309fcf56d64c086ddc264cf --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Ostris, LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index d38b3fc58e3df169b3db5539c0899f2385eb953c..a926f6eb718129e03cad0f18b43de9474d9f90d8 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,40 @@ ---- -title: Flux LoRA Trainning On Modal -emoji: 🏢 -colorFrom: blue -colorTo: red -sdk: gradio -sdk_version: 5.9.1 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# FLUX LoRA Training on Modal + +## IMPORTANT - READ THIS + +This Space allows you to train LoRA models for FLUX using Modal's GPU resources. You **must** use your own API tokens for this Space to function. + +## Setup Instructions: + +1. **Create Accounts:** Create free accounts on [Hugging Face](https://huggingface.co), [Modal](https://modal.com), and [Weights & Biases](https://wandb.ai) (if you want to save training info). +2. 
**Get API Tokens:** + * **Hugging Face:** Obtain a "write" access token from your Hugging Face [settings/tokens](https://huggingface.co/settings/tokens). + * **Modal:** Get your Modal API token from your [Modal dashboard](https://modal.com/). + * **WandB:** Generate a WandB API key from your Weights & Biases settings if you plan to use WandB. +3. **Duplicate This Space:** Duplicate (clone) this Hugging Face Space to your own account. +4. **Add API Tokens as Secrets:** In your duplicated Space, navigate to "Settings" -> "Variables and Secrets" and add: + * `HF_TOKEN`: Your Hugging Face write access token. + * `WANDB_API_KEY`: Your Weights & Biases API key. +5. **Upload your dataset** + * You must upload your dataset to the `/root/ai-toolkit` folder. The images can be in a subfolder, but that folder's path needs to begin with `/root/ai-toolkit/`. For instance: `/root/ai-toolkit/my-dataset` + * Make sure each image file name matches a corresponding text caption in the same folder, for example `image1.jpg` and `image1.txt`. + * You can upload a zip file of your dataset, or just a collection of images and text files. +6. **Customize and Train:** + * Go to the `App` tab and use the Gradio interface to train your LoRA. + * Enter the required information and dataset, and configure the training parameters. + * Choose to upload a zip file or multiple images. + * Make sure your image file names match a corresponding `.txt` file. + * Click the "Start Training" button and wait for the training to complete (check Modal for logs, WandB for training data). + * The UI will automatically upload the file(s) to the Modal compute environment. +7. **View Results:** + * Trained LoRA models will be automatically pushed to your Hugging Face account if you enable the option and have the necessary write token set. + * Samples, logs, optimizer state, and other training information will be stored on WandB if enabled. + * Models, optimizer state, and samples are always stored in `Storage > flux-lora-models` on Modal. + +## Notes + +* Training data will always be located in a dataset folder under `/root/ai-toolkit/your-dataset` in your config; this is required to train correctly. If you upload a folder named `my-dataset` as a zip, the folder path to reference will be: `/root/ai-toolkit/my-dataset`. +* Make sure every image has a corresponding text file with the same name (a quick check script is sketched below). +* Training and downloading samples will take a while, so be patient, especially in low VRAM mode.
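The following is an optional, minimal check script (not part of the toolkit) you can run locally before uploading. It uses only the Python standard library and assumes your images and captions sit together in one folder; the folder name `my-dataset` is a placeholder:

```python
# Minimal sketch: verify every image in a dataset folder has a matching .txt caption.
# Uses only the Python standard library; adjust DATASET_DIR to your local folder.
from pathlib import Path

DATASET_DIR = Path("my-dataset")  # placeholder: your local dataset folder
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

missing = [
    img.name
    for img in sorted(DATASET_DIR.iterdir())
    if img.suffix.lower() in IMAGE_EXTS and not img.with_suffix(".txt").exists()
]

if missing:
    print("Images without captions:", ", ".join(missing))
else:
    print("All images have matching .txt captions.")
```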
If you encounter any problems, please open a new issue on the [ostris/ai-toolkit](https://github.com/ostris/ai-toolkit) GitHub repository. \ No newline at end of file diff --git a/README_origin.md new file mode 100644 index 0000000000000000000000000000000000000000..b19f5044e2c4e27652e587ab236109749b126a33 --- /dev/null +++ b/README_origin.md @@ -0,0 +1,468 @@ +# AI Toolkit by Ostris + +## IMPORTANT NOTE - READ THIS +This is my research repo. I do a lot of experiments in it and it is possible that I will break things. +If something breaks, check out an earlier commit. This repo can train a lot of things, and it is +hard to keep up with all of them. + +## Support my work + + +glif.app + + + +My work on this project would not be possible without the amazing support of [Glif](https://glif.app/) and everyone on the +team. If you want to support me, support Glif. [Join the site](https://glif.app/), +[Join us on Discord](https://discord.com/invite/nuR9zZ2nsh), [follow us on Twitter](https://x.com/heyglif) +and come make some cool stuff with us. + +## Installation + +Requirements: +- python >3.10 +- Nvidia GPU with enough VRAM to do what you need +- python venv +- git + + + +Linux: +```bash +git clone https://github.com/ostris/ai-toolkit.git +cd ai-toolkit +git submodule update --init --recursive +python3 -m venv venv +source venv/bin/activate +# .\venv\Scripts\activate on windows +# install torch first +pip3 install torch +pip3 install -r requirements.txt +``` + +Windows: +```bash +git clone https://github.com/ostris/ai-toolkit.git +cd ai-toolkit +git submodule update --init --recursive +python -m venv venv +.\venv\Scripts\activate +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 +pip install -r requirements.txt +``` + +## FLUX.1 Training + +### Tutorial + +To get started quickly, check out [@araminta_k](https://x.com/araminta_k)'s tutorial on [Finetuning Flux Dev on a 3090](https://www.youtube.com/watch?v=HzGW_Kyermg) with 24GB VRAM. + + +### Requirements +You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control +your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize +the model on CPU and should allow it to train with monitors attached. Users have gotten it to work on Windows with WSL, +but there are some reports of a bug when running on Windows natively. +I have only tested on Linux for now. This is still extremely experimental, +and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all. + +### FLUX.1-dev + +FLUX.1-dev has a non-commercial license, which means anything you train will inherit the +non-commercial license. It is also a gated model, so you need to accept the license on HF before using it. +Otherwise, this will fail. Here are the required steps to set it up. + +1. Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) +2. Make a file named `.env` in the root of this folder +3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here` + +### FLUX.1-schnell + +FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want, and it does not require a HF_TOKEN to train. +However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter). +It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended. + +To use it, you just need to add the assistant to the `model` section of your config file like so: + +```yaml + model: + name_or_path: "black-forest-labs/FLUX.1-schnell" + assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" + is_flux: true + quantize: true +``` + +You also need to adjust your sample steps since schnell does not require as many: + +```yaml + sample: + guidance_scale: 1 # schnell does not do guidance + sample_steps: 4 # 1 - 4 works well +``` + +### Training +1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml` +2. Edit the file following the comments in the file (an optional scripted version of steps 1 and 2 is sketched just after this list) +3. Run the file like so: `python run.py config/whatever_you_want.yml`
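If you would rather script steps 1 and 2 than edit the YAML by hand, here is a minimal sketch. It assumes PyYAML is installed (`pip install pyyaml`) and relies on the key layout of the example configs shown later in this file; the LoRA name and dataset path are placeholders, and note that `yaml.safe_dump` does not preserve the comments from the example file:

```python
# Minimal sketch: copy an example config and set the two fields most people change.
# Run from the ai-toolkit root. Comments in the example YAML are not preserved.
import yaml  # PyYAML
from pathlib import Path

src = Path("config/examples/train_lora_flux_24gb.yaml")
dst = Path("config/my_first_flux_lora_v1.yml")

cfg = yaml.safe_load(src.read_text())
cfg["config"]["name"] = "my_first_flux_lora_v1"  # placeholder LoRA/output name
process = cfg["config"]["process"][0]
process["datasets"][0]["folder_path"] = "/path/to/images/folder"  # placeholder dataset path

dst.write_text(yaml.safe_dump(cfg, sort_keys=False))
print(f"Wrote {dst}; now run: python run.py {dst}")
```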
A folder with the name from the config file will be created inside the training folder when you start. It will have all +checkpoints and images in it. You can stop the training at any time using ctrl+c, and when you resume, it will pick back up +from the last checkpoint. + +IMPORTANT: If you press ctrl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving. + +### Need help? + +Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU) +and ask for help there. However, please refrain from PMing me directly with general questions or support requests. Ask in the Discord +and I will answer when I can. + +## Gradio UI + +To get started training locally with a custom UI, once you have followed the steps above and `ai-toolkit` is installed: + +```bash +cd ai-toolkit #in case you are not yet in the ai-toolkit folder +huggingface-cli login #provide a `write` token to publish your LoRA at the end +python flux_train_ui.py +``` + +This will launch a UI that lets you upload your images, caption them, train, and publish your LoRA. +![image](assets/lora_ease_ui.png) + + +## Training in RunPod +Example RunPod template: **runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04** +> You need a minimum of 24GB VRAM, pick a GPU by your preference. + +#### Example config ($0.5/hr): +- 1x A40 (48 GB VRAM) +- 19 vCPU 100 GB RAM + +#### Custom overrides (you need some storage to clone FLUX.1, store datasets, store trained models and samples): +- ~120 GB Disk +- ~120 GB Pod Volume +- Start Jupyter Notebook + +### 1. Setup +``` +git clone https://github.com/ostris/ai-toolkit.git +cd ai-toolkit +git submodule update --init --recursive +python -m venv venv +source venv/bin/activate +pip install torch +pip install -r requirements.txt +pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues +``` +### 2. Upload your dataset +- Create a new folder in the root, name it `dataset` or whatever you like. +- Drag and drop your .jpg, .jpeg, or .png images and .txt files inside the newly created dataset folder. + +### 3. Log in to Hugging Face with an Access Token +- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to the Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev). +- Run ```huggingface-cli login``` and paste your token. + +### 4. Training +- Copy an example config file located at ```config/examples``` to the config folder and rename it to ```whatever_you_want.yml```. +- Edit the config following the comments in the file. +- Change ```folder_path: "/path/to/images/folder"``` to your dataset path like ```folder_path: "/workspace/ai-toolkit/your-dataset"```. +- Run the file: ```python run.py config/whatever_you_want.yml```. + +### Screenshot from RunPod +RunPod Training Screenshot + +## Training in Modal + +### 1. Setup +#### ai-toolkit: +``` +git clone https://github.com/ostris/ai-toolkit.git +cd ai-toolkit +git submodule update --init --recursive +python -m venv venv +source venv/bin/activate +pip install torch +pip install -r requirements.txt +pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues +``` +#### Modal: +- Run `pip install modal` to install the modal Python package. +- Run `modal setup` to authenticate (if this doesn't work, try `python -m modal setup`).
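The Hugging Face step below uses `huggingface-cli login`. If you prefer to authenticate from Python instead (for example inside a notebook), here is a minimal sketch using `huggingface_hub` and the `.env` convention from the FLUX.1-dev section; where the token lives is an assumption, so adapt it to however you store your key:

```python
# Minimal sketch: log in to Hugging Face programmatically instead of using
# the `huggingface-cli login` prompt. Assumes HF_TOKEN is in the environment
# or in a .env file in the project root (HF_TOKEN=your_key_here).
import os
from pathlib import Path
from huggingface_hub import login

env_file = Path(".env")
if "HF_TOKEN" not in os.environ and env_file.exists():
    for line in env_file.read_text().splitlines():
        if line.startswith("HF_TOKEN="):
            os.environ["HF_TOKEN"] = line.split("=", 1)[1].strip()

token = os.environ.get("HF_TOKEN")
if not token:
    raise SystemExit("HF_TOKEN not found; set it in your environment or .env file.")
login(token=token)
```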
+ +#### Hugging Face: +- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev). +- Run `huggingface-cli login` and paste your token. + +### 2. Upload your dataset +- Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`. + +### 3. Configs +- Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```. +- Edit the config following the comments in the file, **be careful and follow the example `/root/ai-toolkit` paths**. + +### 4. Edit run_modal.py +- Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like: + + ``` + code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit") + ``` +- Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_. + +### 5. Training +- Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`. +- You can monitor your training in your local terminal, or on [modal.com](https://modal.com/). +- Models, samples and optimizer will be stored in `Storage > flux-lora-models`. + +### 6. Saving the model +- Check contents of the volume by running `modal volume ls flux-lora-models`. +- Download the content by running `modal volume get flux-lora-models your-model-name`. +- Example: `modal volume get flux-lora-models my_first_flux_lora_v1`. + +### Screenshot from Modal + +Modal Traning Screenshot + +--- + +## Dataset Preparation + +Datasets generally need to be a folder containing images and associated text files. Currently, the only supported +formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images +but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption. +You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically +replaced. + +Images are never upscaled but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**. +The loader will automatically resize them and can handle varying aspect ratios. + + +## Training Specific Layers + +To train specific layers with LoRA, you can use the `only_if_contains` network kwargs. For instance, if you want to train only the 2 layers +used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your +network kwargs like so: + +```yaml + network: + type: "lora" + linear: 128 + linear_alpha: 128 + network_kwargs: + only_if_contains: + - "transformer.single_transformer_blocks.7.proj_out" + - "transformer.single_transformer_blocks.20.proj_out" +``` + +The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal +the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights. +For instance to only train the `single_transformer` for FLUX.1, you can use the following: + +```yaml + network: + type: "lora" + linear: 128 + linear_alpha: 128 + network_kwargs: + only_if_contains: + - "transformer.single_transformer_blocks." +``` + +You can also exclude layers by their names by using `ignore_if_contains` network kwarg. 
So to exclude all the single transformer blocks: + + +```yaml + network: + type: "lora" + linear: 128 + linear_alpha: 128 + network_kwargs: + ignore_if_contains: + - "transformer.single_transformer_blocks." +``` + +`ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both, +it will be ignored.
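To find strings to use with `only_if_contains` or `ignore_if_contains`, you can simply list the keys of a checkpoint's state dict. Here is a minimal sketch using the `safetensors` library; the checkpoint path is a placeholder, and depending on how the file was saved the keys may or may not carry the `transformer.` prefix:

```python
# Minimal sketch: print layer names from a .safetensors checkpoint so you can
# pick substrings for only_if_contains / ignore_if_contains.
from safetensors import safe_open

checkpoint = "/path/to/transformer/diffusion_pytorch_model.safetensors"  # placeholder path

with safe_open(checkpoint, framework="pt", device="cpu") as f:
    for key in f.keys():
        if "single_transformer_blocks" in key:
            print(key)
```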
--- + +## EVERYTHING BELOW THIS LINE IS OUTDATED + +It may still work like that, but I have not tested it in a while. + +--- + +### Batch Image Generation + +An image generator that can take prompts from a config file or from a txt file and generate them to a +folder. I mainly needed this for an SDXL test I am doing, but added some polish to it so it can be used +for general batch image generation. +It all runs off a config file, which you can find an example of in `config/examples/generate.example.yaml`. +More info is in the comments in the example. + +--- + +### LoRA (lierla), LoCON (LyCORIS) extractor + +It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adds some QOL features +and LoRA (lierla) support. It can do multiple types of extractions in one run. +It all runs off a config file, which you can find an example of in `config/examples/extract.example.yml`. +Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`. +Then you can edit the file to your liking and call it like so: + +```bash +python3 run.py config/whatever_you_want.yml +``` + +You can also put a full path to a config file, if you want to keep it somewhere else. + +```bash +python3 run.py "/home/user/whatever_you_want.yml" +``` + +More notes on how it works are available in the example config file itself. LoRA and LoCON both support +extractions of 'fixed', 'threshold', 'ratio', 'quantile'. I'll update what these do and mean later. +Most people use fixed, which is traditional fixed dimension extraction. + +`process` is an array of different processes to run. You can add a few and mix and match. One LoRA, one LoCON, etc. + +--- + +### LoRA Rescale + +Change `` to `` or whatever you want with the same effect. +A tool for rescaling a LoRA's weights. Should work with LoCON as well, but I have not tested it. +It all runs off a config file, which you can find an example of in `config/examples/mod_lora_scale.yml`. +Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`. +Then you can edit the file to your liking and call it like so: + +```bash +python3 run.py config/whatever_you_want.yml +``` + +You can also put a full path to a config file, if you want to keep it somewhere else. + +```bash +python3 run.py "/home/user/whatever_you_want.yml" +``` + +More notes on how it works are available in the example config file itself. This is useful when making +LoRAs, as the ideal weight is rarely 1.0, but now you can fix that. For sliders, they can have weird scales from -2 to 2 +or even -15 to 15. This will allow you to dial it in so they all have your desired scale. + +--- + +### LoRA Slider Trainer + + + Open In Colab + + +This is how I train most of the recent sliders I have on Civitai; you can check them out in my [Civitai profile](https://civitai.com/user/Ostris/models). +It is based off the work by [p1atdev/LECO](https://github.com/p1atdev/LECO) and [rohitgandikota/erasing](https://github.com/rohitgandikota/erasing), +but has been heavily modified to create sliders rather than erasing concepts. I have a lot more plans on this, but it is +very functional as is. It is also very easy to use. Just copy the example config file in `config/examples/train_slider.example.yml` +to the `config` folder and rename it to `whatever_you_want.yml`. Then you can edit the file to your liking and call it like so: + +```bash +python3 run.py config/whatever_you_want.yml +``` + +There is a lot more information in that example file. You can even run the example as is without any modifications to see +how it works. It will create a slider that turns all animals into dogs (neg) or cats (pos). Just run it like so: + +```bash +python3 run.py config/examples/train_slider.example.yml +``` + +And you will be able to see how it works without configuring anything. No datasets are required for this method. +I will post a better tutorial soon. + +--- + +## Extensions!! + +You can now make and share custom extensions that run within this framework and have all the built-in tools +available to them. I will probably use this as the primary development method going +forward so I don't keep adding more and more features to this base repo. I will likely migrate a lot +of the existing functionality as well to make everything modular. There is an example extension in the `extensions` +folder that shows how to make a model merger extension. All of the code is heavily documented, which is hopefully +enough to get you started. To make an extension, just copy that example and replace all the things you need to. + + +### Model Merger - Example Extension +It is located in the `extensions` folder. It is a fully functional model merger that can merge as many models together +as you want. It is a good example of how to make an extension, but is also a pretty useful feature, since most +mergers can only do one model at a time and this one will take as many as you want to feed it. There is an +example config file in there; just copy that to your `config` folder, rename it to `whatever_you_want.yml`, +and use it like any other config file. + +## WIP Tools + + +### VAE (Variational Auto Encoder) Trainer + +This works, but is not ready for others to use and therefore does not have an example config. +I am still working on it. I will update this when it is ready. +I am adding a lot of features for criteria that I have used in my image enlargement work. A critic (discriminator), +content loss, style loss, and a few more. If you don't know, the VAEs +for stable diffusion (yes, even the MSE one, and SDXL) are horrible at smaller faces, and it holds SD back. I will fix this. +I'll post more about this with better examples later, but here is a quick test of a run through with various VAEs. +Images just went in and out of the VAE. It is much worse on smaller faces than shown here. + + + +--- + +## TODO +- [X] Add proper regs on sliders +- [X] Add SDXL support (base model only for now) +- [ ] Add plain erasing +- [ ] Make Textual inversion network trainer (network that spits out TI embeddings) + +--- + +## Change Log + +#### 2023-08-05 + - Huge memory rework and slider rework. Slider training is better than ever with no more +RAM spikes. I also made it so all 4 parts of the slider algorithm run in one batch so they share gradient +accumulation. This makes it much faster and more stable. + - Updated the example config to be something more practical and more up to date with current methods. It is now +a detail slider and shows how to train one without a subject. 512x512 slider training for 1.5 should work on +a 6GB GPU now. Will test soon to verify.
+ + +#### 2021-10-20 + - Windows support bug fixes + - Extensions! Added functionality to make and share custom extensions for training, merging, whatever. +check out the example in the `extensions` folder. Read more about that above. + - Model Merging, provided via the example extension. + +#### 2023-08-03 +Another big refactor to make SD more modular. + +Made batch image generation script + +#### 2023-08-01 +Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so +Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable +at the moment, so hopefully there are not breaking changes. + +Unfortunately, I am too lazy to write a proper changelog with all the changes. + +I added SDXL training to sliders... but.. it does not work properly. +The slider training relies on a model's ability to understand that an unconditional (negative prompt) +means you do not want that concept in the output. SDXL does not understand this for whatever reason, +which makes separating out +concepts within the model hard. I am sure the community will find a way to fix this +over time, but for now, it is not +going to work properly. And if any of you are thinking "Could we maybe fix it by adding 1 or 2 more text +encoders to the model as well as a few more entirely separate diffusion networks?" No. God no. It just needs a little +training without every experimental new paper added to it. The KISS principal. + + +#### 2023-07-30 +Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a +regularizer. You can set the network multiplier to force spread consistency at high weights + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..fb49d1ec1424d167d9aae6d9c45deddd8ff12de4 --- /dev/null +++ b/app.py @@ -0,0 +1,6 @@ +import os +import sys +from hf_ui import demo + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/assets/glif.svg b/assets/glif.svg new file mode 100644 index 0000000000000000000000000000000000000000..b898c62e9fe79ca85e585f3f779ef367096a34a7 --- /dev/null +++ b/assets/glif.svg @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/lora_ease_ui.png b/assets/lora_ease_ui.png new file mode 100644 index 0000000000000000000000000000000000000000..914b8dd541978df8f3cd46f64bb2d3f4b597521c Binary files /dev/null and b/assets/lora_ease_ui.png differ diff --git a/build_and_push_docker.yaml b/build_and_push_docker.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bbda385ac2a467ec8ac2c861860606cb0e46ce50 --- /dev/null +++ b/build_and_push_docker.yaml @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo." +# wait 2 seconds +sleep 2 +docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:latest -f docker/Dockerfile . +docker tag aitoolkit:latest ostris/aitoolkit:latest +docker push ostris/aitoolkit:latest \ No newline at end of file diff --git a/config/examples/extract.example.yml b/config/examples/extract.example.yml new file mode 100644 index 0000000000000000000000000000000000000000..52505bb9058d81d4be3881bb20ecf6da214f571f --- /dev/null +++ b/config/examples/extract.example.yml @@ -0,0 +1,75 @@ +--- +# this is in yaml format. 
You can use json if you prefer +# I like both but yaml is easier to read and write +# plus it has comments which is nice for documentation +job: extract # tells the runner what to do +config: + # the name will be used to create a folder in the output folder + # it will also replace any [name] token in the rest of this config + name: name_of_your_model + # can be hugging face model, a .ckpt, or a .safetensors + base_model: "/path/to/base/model.safetensors" + # can be hugging face model, a .ckpt, or a .safetensors + extract_model: "/path/to/model/to/extract/trained.safetensors" + # we will create folder here with name above so. This will create /path/to/output/folder/name_of_your_model + output_folder: "/path/to/output/folder" + is_v2: false + dtype: fp16 # saved dtype + device: cpu # cpu, cuda:0, etc + + # processes can be chained like this to run multiple in a row + # they must all use same models above, but great for testing different + # sizes and typed of extractions. It is much faster as we already have the models loaded + process: + # process 1 + - type: locon # locon or lora (locon is lycoris) + filename: "[name]_64_32.safetensors" # will be put in output folder + dtype: fp16 + mode: fixed + linear: 64 + conv: 32 + + # process 2 + - type: locon + output_path: "/absolute/path/for/this/output.safetensors" # can be absolute + mode: ratio + linear: 0.2 + conv: 0.2 + + # process 3 + - type: locon + filename: "[name]_ratio_02.safetensors" + mode: quantile + linear: 0.5 + conv: 0.5 + + # process 4 + - type: lora # traditional lora extraction (lierla) with linear layers only + filename: "[name]_4.safetensors" + mode: fixed # fixed, ratio, quantile supported for lora as well + linear: 4 # lora dim or rank + # no conv for lora + + # process 5 + - type: lora + filename: "[name]_q05.safetensors" + mode: quantile + linear: 0.5 + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. diff --git a/config/examples/generate.example.yaml b/config/examples/generate.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a3e19efdfee6f6215f37bb741a7b12e7c5ad484 --- /dev/null +++ b/config/examples/generate.example.yaml @@ -0,0 +1,60 @@ +--- + +job: generate # tells the runner what to do +config: + name: "generate" # this is not really used anywhere currently but required by runner + process: + # process 1 + - type: to_folder # process images to a folder + output_folder: "output/gen" + device: cuda:0 # cpu, cuda:0, etc + generate: + # these are your defaults you can override most of them with flags + sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now + width: 1024 + height: 1024 + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: -1 # -1 is random + guidance_scale: 7 + sample_steps: 20 + ext: ".png" # .png, .jpg, .jpeg, .webp + + # here ate the flags you can use for prompts. Always start with + # your prompt first then add these flags after. 
You can use as many + # like + # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20 + # we will try to support all sd-scripts flags where we can + + # FROM SD-SCRIPTS + # --n Treat everything until the next option as a negative prompt. + # --w Specify the width of the generated image. + # --h Specify the height of the generated image. + # --d Specify the seed for the generated image. + # --l Specify the CFG scale for the generated image. + # --s Specify the number of steps during generation. + + # OURS and some QOL additions + # --p2 Prompt for the second text encoder (SDXL only) + # --n2 Negative prompt for the second text encoder (SDXL only) + # --gr Specify the guidance rescale for the generated image (SDXL only) + # --seed Specify the seed for the generated image same as --d + # --cfg Specify the CFG scale for the generated image same as --l + # --steps Specify the number of steps during generation same as --s + + prompt_file: false # if true a txt file will be created next to images with prompt strings used + # prompts can also be a path to a text file with one prompt per line + # prompts: "/path/to/prompts.txt" + prompts: + - "photo of batman" + - "photo of superman" + - "photo of spiderman" + - "photo of a superhero --n batman superman spiderman" + + model: + # huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt + # name_or_path: "runwayml/stable-diffusion-v1-5" + name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors" + is_v2: false # for v2 models + is_v_pred: false # for v-prediction models (most v2 models) + is_xl: false # for SDXL models + dtype: bf16 diff --git a/config/examples/mod_lora_scale.yaml b/config/examples/mod_lora_scale.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f59ecc838b1e4a4465600a67ddb690331ba3255 --- /dev/null +++ b/config/examples/mod_lora_scale.yaml @@ -0,0 +1,48 @@ +--- +job: mod +config: + name: name_of_your_model_v1 + process: + - type: rescale_lora + # path to your current lora model + input_path: "/path/to/lora/lora.safetensors" + # output path for your new lora model, can be the same as input_path to replace + output_path: "/path/to/lora/output_lora_v1.safetensors" + # replaces meta with the meta below (plus minimum meta fields) + # if false, we will leave the meta alone except for updating hashes (sd-script hashes) + replace_meta: true + # how to adjust, we can scale the up_down weights or the alpha + # up_down is the default and probably the best, they will both net the same outputs + # would only affect rare NaN cases and maybe merging with old merge tools + scale_target: 'up_down' + # precision to save, fp16 is the default and standard + save_dtype: fp16 + # current_weight is the ideal weight you use as a multiplier when using the lora + # IE in automatic1111 the 6.0 is the current_weight + # you can do negatives here too if you want to flip the lora + current_weight: 6.0 + # target_weight is the ideal weight you use as a multiplier when using the lora + # instead of the one above. 
IE in automatic1111 instead of using + # we want to use so 1.0 is the target_weight + target_weight: 1.0 + + # base model for the lora + # this is just used to add meta so automatic111 knows which model it is for + # assume v1.5 if these are not set + is_xl: false + is_v2: false +meta: + # this is only used if you set replace_meta to true above + name: "[name]" # [name] gets replaced with the name above + description: A short description of your lora + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. diff --git a/config/examples/modal/modal_train_lora_flux_24gb.yaml b/config/examples/modal/modal_train_lora_flux_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51873de0b0b5e0e86153cc19c7561ba18272ec34 --- /dev/null +++ b/config/examples/modal/modal_train_lora_flux_24gb.yaml @@ -0,0 +1,96 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir: + - folder_path: "/root/ai-toolkit/your-dataset" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new vell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. 
+ ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + # if you get an error, or get stuck while downloading, + # check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and + # place it like "/root/ai-toolkit/FLUX.1-dev" + name_or_path: "black-forest-labs/FLUX.1-dev" + is_flux: true + quantize: true # run 8bit mixed precision +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower. + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml b/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d1e964fe9479e333566fd89ee8a4b2724231346 --- /dev/null +++ b/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml @@ -0,0 +1,98 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir: + - folder_path: "/root/ai-toolkit/your-dataset" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new vell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + # if you get an error, or get stuck while downloading, + # check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and + # place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter" + name_or_path: "black-forest-labs/FLUX.1-schnell" + assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training + is_flux: true + quantize: true # run 8bit mixed precision + # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower. 
+ sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 1 # schnell does not do guidance + sample_steps: 4 # 1 - 4 works well +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/config/examples/train_lora_flux_24gb.yaml b/config/examples/train_lora_flux_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e29402b2668b23a215f5ddc10d083e55def7c61 --- /dev/null +++ b/config/examples/train_lora_flux_24gb.yaml @@ -0,0 +1,96 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new vell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "black-forest-labs/FLUX.1-dev" + is_flux: true + quantize: true # run 8bit mixed precision +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower. + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/config/examples/train_lora_flux_schnell_24gb.yaml b/config/examples/train_lora_flux_schnell_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4aef078d61765e70a9c1820109075af82367d9f --- /dev/null +++ b/config/examples/train_lora_flux_schnell_24gb.yaml @@ -0,0 +1,98 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "black-forest-labs/FLUX.1-schnell" + assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training + is_flux: true + quantize: true # run 8bit mixed precision + # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary +# low_vram: true # uncomment this if the GPU is connected to your monitors. 
It will use less vram to quantize, but is slower. + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 1 # schnell does not do guidance + sample_steps: 4 # 1 - 4 works well +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/config/examples/train_lora_sd35_large_24gb.yaml b/config/examples/train_lora_sd35_large_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1766c39220a38134c9db8319d2b7e243f8b8f43 --- /dev/null +++ b/config/examples/train_lora_sd35_large_24gb.yaml @@ -0,0 +1,97 @@ +--- +# NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE +job: extension +config: + # this name will be the folder and filename name + name: "my_first_sd3l_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 1024 ] + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # May not fully work with SD3 yet + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" + timestep_type: "linear" # linear or sigmoid + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new vell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for sd3, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "stabilityai/stable-diffusion-3.5-large" + is_v3: true + quantize: true # run 8bit mixed precision + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/config/examples/train_slider.example.yml b/config/examples/train_slider.example.yml new file mode 100644 index 0000000000000000000000000000000000000000..b36009175af6b384527a2fb57e0f316caba6049a --- /dev/null +++ b/config/examples/train_slider.example.yml @@ -0,0 +1,230 @@ +--- +# This is in yaml format. 
You can use json if you prefer +# I like both but yaml is easier to write +# Plus it has comments which is nice for documentation +# This is the config I use on my sliders, It is solid and tested +job: train +config: + # the name will be used to create a folder in the output folder + # it will also replace any [name] token in the rest of this config + name: detail_slider_v1 + # folder will be created with name above in folder below + # it can be relative to the project root or absolute + training_folder: "output/LoRA" + device: cuda:0 # cpu, cuda:0, etc + # for tensorboard logging, we will make a subfolder for this job + log_dir: "output/.tensorboard" + # you can stack processes for other jobs, It is not tested with sliders though + # just use one for now + process: + - type: slider # tells runner to run the slider process + # network is the LoRA network for a slider, I recommend to leave this be + network: + # network type lierla is traditional LoRA that works everywhere, only linear layers + type: "lierla" + # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good + linear: 8 + linear_alpha: 4 # Do about half of rank + # training config + train: + # this is also used in sampling. Stick with ddpm unless you know what you are doing + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + # how many steps to train. More is not always better. I rarely go over 1000 + steps: 500 + # I have had good results with 4e-4 to 1e-4 at 500 steps + lr: 2e-4 + # enables gradient checkpoint, saves vram, leave it on + gradient_checkpointing: true + # train the unet. I recommend leaving this true + train_unet: true + # train the text encoder. I don't recommend this unless you have a special use case + # for sliders we are adjusting representation of the concept (unet), + # not the description of it (text encoder) + train_text_encoder: false + # same as from sd-scripts, not fully tested but should speed up training + min_snr_gamma: 5.0 + # just leave unless you know what you are doing + # also supports "dadaptation" but set lr to 1 if you use that, + # but it learns too fast and I don't recommend it + optimizer: "adamw" + # only constant for now + lr_scheduler: "constant" + # we randomly denoise random num of steps form 1 to this number + # while training. Just leave it + max_denoising_steps: 40 + # works great at 1. I do 1 even with my 4090. + # higher may not work right with newer single batch stacking code anyway + batch_size: 1 + # bf16 works best if your GPU supports it (modern) + dtype: bf16 # fp32, bf16, fp16 + # if you have it, use it. It is faster and better + # torch 2.0 doesnt need xformers anymore, only use if you have lower version +# xformers: true + # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX + # although, the way we train sliders is comparative, so it probably won't work anyway + noise_offset: 0.0 +# noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL + + # the model to train the LoRA network on + model: + # huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt + name_or_path: "runwayml/stable-diffusion-v1-5" + is_v2: false # for v2 models + is_v_pred: false # for v-prediction models (most v2 models) + # has some issues with the dual text encoder and the way we train sliders + # it works bit weights need to probably be higher to see it. + is_xl: false # for SDXL models + + # saving config + save: + dtype: float16 # precision to save. 
I recommend float16 + save_every: 50 # save every this many steps + # this will remove step counts more than this number + # allows you to save more often in case of a crash without filling up your drive + max_step_saves_to_keep: 2 + + # sampling config + sample: + # must match train.noise_scheduler, this is not used here + # but may be in future and in other processes + sampler: "ddpm" + # sample every this many steps + sample_every: 20 + # image size + width: 512 + height: 512 + # prompts to use for sampling. Do as many as you want, but it slows down training + # pick ones that will best represent the concept you are trying to adjust + # allows some flags after the prompt + # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive + # slide are good tests. will inherit sample.network_multiplier if not set + # --n [string] # negative prompt, will inherit sample.neg if not set + # Only 75 tokens allowed currently + # I like to do a wide positive and negative spread so I can see a good range and stop + # early if the network is braking down + prompts: + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5" + - "a golden retriever sitting on a leather couch, --m -5" + - "a golden retriever sitting on a leather couch --m -3" + - "a golden retriever sitting on a leather couch --m 3" + - "a golden retriever sitting on a leather couch --m 5" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5" + # negative prompt used on all prompts above as default if they don't have one + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome" + # seed for sampling. 42 is the answer for everything + seed: 42 + # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc + # will start over on next sample_every so s1 is always seed + # works well if you use same prompt but want different results + walk_seed: false + # cfg scale (4 to 10 is good) + guidance_scale: 7 + # sampler steps (20 to 30 is good) + sample_steps: 20 + # default network multiplier for all prompts + # since we are training a slider, I recommend overriding this with --m [number] + # in the prompts above to get both sides of the slider + network_multiplier: 1.0 + + # logging information + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false # probably done need unless you are debugging + + # slider training config, best for last + slider: + # resolutions to train on. [ width, height ]. This is less important for sliders + # as we are not teaching the model anything it doesn't already know + # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1 + # and [ 1024, 1024 ] for sd_xl + # you can do as many as you want here + resolutions: + - [ 512, 512 ] +# - [ 512, 768 ] +# - [ 768, 768 ] + # slider training uses 4 combined steps for a single round. This will do it in one gradient + # step. 
It is highly optimized and shouldn't take anymore vram than doing without it, + # since we break down batches for gradient accumulation now. so just leave it on. + batch_full_slide: true + # These are the concepts to train on. You can do as many as you want here, + # but they can conflict outweigh each other. Other than experimenting, I recommend + # just doing one for good results + targets: + # target_class is the base concept we are adjusting the representation of + # for example, if we are adjusting the representation of a person, we would use "person" + # if we are adjusting the representation of a cat, we would use "cat" It is not + # a keyword necessarily but what the model understands the concept to represent. + # "person" will affect men, women, children, etc but will not affect cats, dogs, etc + # it is the models base general understanding of the concept and everything it represents + # you can leave it blank to affect everything. In this example, we are adjusting + # detail, so we will leave it blank to affect everything + - target_class: "" + # positive is the prompt for the positive side of the slider. + # It is the concept that will be excited and amplified in the model when we slide the slider + # to the positive side and forgotten / inverted when we slide + # the slider to the negative side. It is generally best to include the target_class in + # the prompt. You want it to be the extreme of what you want to train on. For example, + # if you want to train on fat people, you would use "an extremely fat, morbidly obese person" + # as the prompt. Not just "fat person" + # max 75 tokens for now + positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" + # negative is the prompt for the negative side of the slider and works the same as positive + # it does not necessarily work the same as a negative prompt when generating images + # these need to be polar opposites. + # max 76 tokens for now + negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" + # the loss for this target is multiplied by this number. + # if you are doing more than one target it may be good to set less important ones + # to a lower number like 0.1 so they don't outweigh the primary target + weight: 1.0 + # shuffle the prompts split by the comma. We will run every combination randomly + # this will make the LoRA more robust. You probably want this on unless prompt order + # is important for some reason + shuffle: true + + + # anchors are prompts that we will try to hold on to while training the slider + # these are NOT necessary and can prevent the slider from converging if not done right + # leave them off if you are having issues, but they can help lock the network + # on certain concepts to help prevent catastrophic forgetting + # you want these to generate an image that is not your target_class, but close to it + # is fine as long as it does not directly overlap it. + # For example, if you are training on a person smiling, + # you could use "a person with a face mask" as an anchor. It is a person, the image is the same + # regardless if they are smiling or not, however, the closer the concept is to the target_class + # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually + # for close concepts, you want to be closer to 0.1 or 0.2 + # these will slow down training. 
I am leaving them off for the demo + +# anchors: +# - prompt: "a woman" +# neg_prompt: "animal" +# # the multiplier applied to the LoRA when this is run. +# # higher will give it more weight but also help keep the lora from collapsing +# multiplier: 1.0 +# - prompt: "a man" +# neg_prompt: "animal" +# multiplier: 1.0 +# - prompt: "a person" +# neg_prompt: "animal" +# multiplier: 1.0 + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..e98572c45341e838c5fa0c66ef840db48ec48fd9 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,31 @@ +FROM runpod/base:0.6.2-cuda12.2.0 + +LABEL authors="jaret" + +# Install dependencies +RUN apt-get update + +WORKDIR /app +ARG CACHEBUST=1 +RUN git clone https://github.com/ostris/ai-toolkit.git && \ + cd ai-toolkit && \ + git submodule update --init --recursive + +WORKDIR /app/ai-toolkit + +RUN ln -s /usr/bin/python3 /usr/bin/python +RUN python -m pip install -r requirements.txt + +RUN apt-get install -y tmux nvtop htop + +RUN pip install jupyterlab + +# mask workspace +RUN mkdir /workspace + + +# symlink app to workspace +RUN ln -s /app/ai-toolkit /workspace/ai-toolkit + +WORKDIR / +CMD ["/start.sh"] \ No newline at end of file diff --git a/extensions/example/ExampleMergeModels.py b/extensions/example/ExampleMergeModels.py new file mode 100644 index 0000000000000000000000000000000000000000..162d514c38799b1c5cdc4717e1fa9867a4b35572 --- /dev/null +++ b/extensions/example/ExampleMergeModels.py @@ -0,0 +1,129 @@ +import torch +import gc +from collections import OrderedDict +from typing import TYPE_CHECKING +from jobs.process import BaseExtensionProcess +from toolkit.config_modules import ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion +from toolkit.train_tools import get_torch_dtype +from tqdm import tqdm + +# Type check imports. Prevents circular imports +if TYPE_CHECKING: + from jobs import ExtensionJob + + +# extend standard config classes to add weight +class ModelInputConfig(ModelConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.weight = kwargs.get('weight', 1.0) + # overwrite default dtype unless user specifies otherwise + # float 32 will give up better precision on the merging functions + self.dtype: str = kwargs.get('dtype', 'float32') + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +# this is our main class process +class ExampleMergeModels(BaseExtensionProcess): + def __init__( + self, + process_id: int, + job: 'ExtensionJob', + config: OrderedDict + ): + super().__init__(process_id, job, config) + # this is the setup process, do not do process intensive stuff here, just variable setup and + # checking requirements. 
This is called before the run() function + # no loading models or anything like that, it is just for setting up the process + # all of your process intensive stuff should be done in the run() function + # config will have everything from the process item in the config file + + # convince methods exist on BaseProcess to get config values + # if required is set to true and the value is not found it will throw an error + # you can pass a default value to get_conf() as well if it was not in the config file + # as well as a type to cast the value to + self.save_path = self.get_conf('save_path', required=True) + self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype) + self.device = self.get_conf('device', default='cpu', as_type=torch.device) + + # build models to merge list + models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list) + # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config + # this way you can add methods to it and it is easier to read and code. There are a lot of + # inbuilt config classes located in toolkit.config_modules as well + self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge] + # setup is complete. Don't load anything else here, just setup variables and stuff + + # this is the entire run process be sure to call super().run() first + def run(self): + # always call first + super().run() + print(f"Running process: {self.__class__.__name__}") + + # let's adjust our weights first to normalize them so the total is 1.0 + total_weight = sum([model.weight for model in self.models_to_merge]) + weight_adjust = 1.0 / total_weight + for model in self.models_to_merge: + model.weight *= weight_adjust + + output_model: StableDiffusion = None + # let's do the merge, it is a good idea to use tqdm to show progress + for model_config in tqdm(self.models_to_merge, desc="Merging models"): + # setup model class with our helper class + sd_model = StableDiffusion( + device=self.device, + model_config=model_config, + dtype="float32" + ) + # load the model + sd_model.load_model() + + # adjust the weight of the text encoder + if isinstance(sd_model.text_encoder, list): + # sdxl model + for text_encoder in sd_model.text_encoder: + for key, value in text_encoder.state_dict().items(): + value *= model_config.weight + else: + # normal model + for key, value in sd_model.text_encoder.state_dict().items(): + value *= model_config.weight + # adjust the weights of the unet + for key, value in sd_model.unet.state_dict().items(): + value *= model_config.weight + + if output_model is None: + # use this one as the base + output_model = sd_model + else: + # merge the models + # text encoder + if isinstance(output_model.text_encoder, list): + # sdxl model + for i, text_encoder in enumerate(output_model.text_encoder): + for key, value in text_encoder.state_dict().items(): + value += sd_model.text_encoder[i].state_dict()[key] + else: + # normal model + for key, value in output_model.text_encoder.state_dict().items(): + value += sd_model.text_encoder.state_dict()[key] + # unet + for key, value in output_model.unet.state_dict().items(): + value += sd_model.unet.state_dict()[key] + + # remove the model to free memory + del sd_model + flush() + + # merge loop is done, let's save the model + print(f"Saving merged model to {self.save_path}") + output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype) + print(f"Saved merged model to {self.save_path}") + # do cleanup here + 
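# dropping the last reference and flushing the CUDA cache frees the GPU memory the merged model was using +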
del output_model + flush() diff --git a/extensions/example/__init__.py b/extensions/example/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34f348f1a1c278d7b71d48a60d5ddf909141b7db --- /dev/null +++ b/extensions/example/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# We make a subclass of Extension +class ExampleMergeExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "example_merge_extension" + + # name is the name of the extension for printing + name = "Example Merge Extension" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ExampleMergeModels import ExampleMergeModels + return ExampleMergeModels + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ExampleMergeExtension +] diff --git a/extensions/example/config/config.example.yaml b/extensions/example/config/config.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..abed03fd9e9197e072a73203c67ad7ee976912cd --- /dev/null +++ b/extensions/example/config/config.example.yaml @@ -0,0 +1,48 @@ +--- +# Always include at least one example config file to show how to use your extension. +# use plenty of comments so users know how to use it and what everything does + +# all extensions will use this job name +job: extension +config: + name: 'my_awesome_merge' + process: + # Put your example processes here. This will be passed + # to your extension process in the config argument. + # the type MUST match your extension uid + - type: "example_merge_extension" + # save path for the merged model + save_path: "output/merge/[name].safetensors" + # save type + dtype: fp16 + # device to run it on + device: cuda:0 + # input models can only be SD1.x and SD2.x models for this example (currently) + models_to_merge: + # weights are relative, total weights will be normalized + # for example. If you have 2 models with weight 1.0, they will + # both be weighted 0.5. If you have 1 model with weight 1.0 and + # another with weight 2.0, the first will be weighted 1/3 and the + # second will be weighted 2/3 + - name_or_path: "input/model1.safetensors" + weight: 1.0 + - name_or_path: "input/model2.safetensors" + weight: 1.0 + - name_or_path: "input/model3.safetensors" + weight: 0.3 + - name_or_path: "input/model4.safetensors" + weight: 1.0 + + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. 
\ No newline at end of file diff --git a/extensions_built_in/advanced_generator/Img2ImgGenerator.py b/extensions_built_in/advanced_generator/Img2ImgGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..58713bd0c570e1240811cefd8e8b1f390346c069 --- /dev/null +++ b/extensions_built_in/advanced_generator/Img2ImgGenerator.py @@ -0,0 +1,256 @@ +import math +import os +import random +from collections import OrderedDict +from typing import List + +import numpy as np +from PIL import Image +from diffusers import T2IAdapter +from diffusers.utils.torch_utils import randn_tensor +from torch.utils.data import DataLoader +from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline +from tqdm import tqdm + +from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.sampler import get_sampler +from toolkit.stable_diffusion_model import StableDiffusion +import gc +import torch +from jobs.process import BaseExtensionProcess +from toolkit.data_loader import get_dataloader_from_datasets +from toolkit.train_tools import get_torch_dtype +from controlnet_aux.midas import MidasDetector +from diffusers.utils import load_image +from torchvision.transforms import ToTensor + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + + + + +class GenerateConfig: + + def __init__(self, **kwargs): + self.prompts: List[str] + self.sampler = kwargs.get('sampler', 'ddpm') + self.neg = kwargs.get('neg', '') + self.seed = kwargs.get('seed', -1) + self.walk_seed = kwargs.get('walk_seed', False) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.ext = kwargs.get('ext', 'png') + self.denoise_strength = kwargs.get('denoise_strength', 0.5) + self.trigger_word = kwargs.get('trigger_word', None) + + +class Img2ImgGenerator(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.copy_inputs_to = self.get_conf('copy_inputs_to', None) + self.device = self.get_conf('device', 'cuda') + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.generate_config = GenerateConfig(**self.get_conf('generate', required=True)) + self.is_latents_cached = True + raw_datasets = self.get_conf('datasets', None) + if raw_datasets is not None and len(raw_datasets) > 0: + raw_datasets = preprocess_dataset_raw_config(raw_datasets) + self.datasets = None + self.datasets_reg = None + self.dtype = self.get_conf('dtype', 'float16') + self.torch_dtype = get_torch_dtype(self.dtype) + self.params = [] + if raw_datasets is not None and len(raw_datasets) > 0: + for raw_dataset in raw_datasets: + dataset = DatasetConfig(**raw_dataset) + is_caching = dataset.cache_latents or dataset.cache_latents_to_disk + if not is_caching: + self.is_latents_cached = False + if dataset.is_reg: + if self.datasets_reg is None: + self.datasets_reg = [] + self.datasets_reg.append(dataset) + else: + if self.datasets is None: + self.datasets = [] + self.datasets.append(dataset) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.dtype, + ) + print(f"Using device {self.device}") + self.data_loader: DataLoader = None + 
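# placeholders; run() builds the data loader from the configured datasets before generating +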
self.adapter: T2IAdapter = None + + def to_pil(self, img): + # image comes in -1 to 1. convert to a PIL RGB image + img = (img + 1) / 2 + img = img.clamp(0, 1) + img = img[0].permute(1, 2, 0).cpu().numpy() + img = (img * 255).astype(np.uint8) + image = Image.fromarray(img) + return image + + def run(self): + with torch.no_grad(): + super().run() + print("Loading model...") + self.sd.load_model() + device = torch.device(self.device) + + if self.model_config.is_xl: + pipe = StableDiffusionXLImg2ImgPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder[0], + text_encoder_2=self.sd.text_encoder[1], + tokenizer=self.sd.tokenizer[0], + tokenizer_2=self.sd.tokenizer[1], + scheduler=get_sampler(self.generate_config.sampler), + ).to(device, dtype=self.torch_dtype) + elif self.model_config.is_pixart: + pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype) + else: + raise NotImplementedError("Only XL models are supported") + pipe.set_progress_bar_config(disable=True) + + # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True) + + self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd) + + num_batches = len(self.data_loader) + pbar = tqdm(total=num_batches, desc="Generating images") + seed = self.generate_config.seed + # load images from datasets, use tqdm + for i, batch in enumerate(self.data_loader): + batch: DataLoaderBatchDTO = batch + + gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1) + generator = torch.manual_seed(gen_seed) + + file_item: FileItemDTO = batch.file_items[0] + img_path = file_item.path + img_filename = os.path.basename(img_path) + img_filename_no_ext = os.path.splitext(img_filename)[0] + img_filename = img_filename_no_ext + '.' 
+ self.generate_config.ext + output_path = os.path.join(self.output_folder, img_filename) + output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt') + + if self.copy_inputs_to is not None: + output_inputs_path = os.path.join(self.copy_inputs_to, img_filename) + output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt') + else: + output_inputs_path = None + output_inputs_caption_path = None + + caption = batch.get_caption_list()[0] + if self.generate_config.trigger_word is not None: + caption = caption.replace('[trigger]', self.generate_config.trigger_word) + + img: torch.Tensor = batch.tensor.clone() + image = self.to_pil(img) + + # image.save(output_depth_path) + if self.model_config.is_pixart: + pipe: PixArtSigmaPipeline = pipe + + # Encode the full image once + encoded_image = pipe.vae.encode( + pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype)) + if hasattr(encoded_image, "latent_dist"): + latents = encoded_image.latent_dist.sample(generator) + elif hasattr(encoded_image, "latents"): + latents = encoded_image.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + latents = pipe.vae.config.scaling_factor * latents + + # latents = self.sd.encode_images(img) + + # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps) + # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength) + # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0) + # timestep = timestep.to(device, dtype=torch.int32) + # latent = latent.to(device, dtype=self.torch_dtype) + # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype) + # latent = self.sd.add_noise(latent, noise, timestep) + # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:] + batch_size = 1 + num_images_per_prompt = 1 + + shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor, + image.width // pipe.vae_scale_factor) + noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype) + + # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype) + num_inference_steps = self.generate_config.sample_steps + strength = self.generate_config.denoise_strength + # Get timesteps + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + pipe.scheduler.set_timesteps(num_inference_steps, device="cpu") + timesteps = pipe.scheduler.timesteps[t_start:] + timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents = pipe.scheduler.add_noise(latents, noise, timestep) + + gen_images = pipe.__call__( + prompt=caption, + negative_prompt=self.generate_config.neg, + latents=latents, + timesteps=timesteps, + width=image.width, + height=image.height, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + guidance_scale=self.generate_config.guidance_scale, + # strength=self.generate_config.denoise_strength, + use_resolution_binning=False, + output_type="np" + ).images[0] + gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8) + gen_images = Image.fromarray(gen_images) + else: + pipe: StableDiffusionXLImg2ImgPipeline = pipe + + gen_images = pipe.__call__( + prompt=caption, + negative_prompt=self.generate_config.neg, + image=image, + num_inference_steps=self.generate_config.sample_steps, + 
guidance_scale=self.generate_config.guidance_scale, + strength=self.generate_config.denoise_strength, + ).images[0] + os.makedirs(os.path.dirname(output_path), exist_ok=True) + gen_images.save(output_path) + + # save caption + with open(output_caption_path, 'w') as f: + f.write(caption) + + if output_inputs_path is not None: + os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True) + image.save(output_inputs_path) + with open(output_inputs_caption_path, 'w') as f: + f.write(caption) + + pbar.update(1) + batch.cleanup() + + pbar.close() + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/extensions_built_in/advanced_generator/PureLoraGenerator.py b/extensions_built_in/advanced_generator/PureLoraGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..ec19da317230f5a837ebc6561fd4dc7cac2fd946 --- /dev/null +++ b/extensions_built_in/advanced_generator/PureLoraGenerator.py @@ -0,0 +1,102 @@ +import os +from collections import OrderedDict + +from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig +from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm +from toolkit.sd_device_states_presets import get_train_sd_device_state_preset +from toolkit.stable_diffusion_model import StableDiffusion +import gc +import torch +from jobs.process import BaseExtensionProcess +from toolkit.train_tools import get_torch_dtype + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class PureLoraGenerator(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.device = self.get_conf('device', 'cuda') + self.device_torch = torch.device(self.device) + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.generate_config = SampleConfig(**self.get_conf('sample', required=True)) + self.dtype = self.get_conf('dtype', 'float16') + self.torch_dtype = get_torch_dtype(self.dtype) + lorm_config = self.get_conf('lorm', None) + self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None + + self.device_state_preset = get_train_sd_device_state_preset( + device=torch.device(self.device), + ) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.dtype, + ) + + def run(self): + super().run() + print("Loading model...") + with torch.no_grad(): + self.sd.load_model() + self.sd.unet.eval() + self.sd.unet.to(self.device_torch) + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + te.to(self.device_torch) + else: + self.sd.text_encoder.eval() + self.sd.to(self.device_torch) + + print(f"Converting to LoRM UNet") + # replace the unet with LoRMUnet + convert_diffusers_unet_to_lorm( + self.sd.unet, + config=self.lorm_config, + ) + + sample_folder = os.path.join(self.output_folder) + gen_img_config_list = [] + + sample_config = self.generate_config + start_seed = sample_config.seed + current_seed = start_seed + for i in range(len(sample_config.prompts)): + if sample_config.walk_seed: + current_seed = start_seed + i + + filename = f"[time]_[count].{self.generate_config.ext}" + output_path = os.path.join(sample_folder, filename) + prompt = sample_config.prompts[i] + extra_args = {} + gen_img_config_list.append(GenerateImageConfig( + prompt=prompt, # it will autoparse the prompt + 
width=sample_config.width, + height=sample_config.height, + negative_prompt=sample_config.neg, + seed=current_seed, + guidance_scale=sample_config.guidance_scale, + guidance_rescale=sample_config.guidance_rescale, + num_inference_steps=sample_config.sample_steps, + network_multiplier=sample_config.network_multiplier, + output_path=output_path, + output_ext=sample_config.ext, + adapter_conditioning_scale=sample_config.adapter_conditioning_scale, + **extra_args + )) + + # send to be generated + self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler) + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/extensions_built_in/advanced_generator/ReferenceGenerator.py b/extensions_built_in/advanced_generator/ReferenceGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..19e3b6e55786cde2ddcebfdc60230223dc443b62 --- /dev/null +++ b/extensions_built_in/advanced_generator/ReferenceGenerator.py @@ -0,0 +1,212 @@ +import os +import random +from collections import OrderedDict +from typing import List + +import numpy as np +from PIL import Image +from diffusers import T2IAdapter +from torch.utils.data import DataLoader +from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline +from tqdm import tqdm + +from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.sampler import get_sampler +from toolkit.stable_diffusion_model import StableDiffusion +import gc +import torch +from jobs.process import BaseExtensionProcess +from toolkit.data_loader import get_dataloader_from_datasets +from toolkit.train_tools import get_torch_dtype +from controlnet_aux.midas import MidasDetector +from diffusers.utils import load_image + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class GenerateConfig: + + def __init__(self, **kwargs): + self.prompts: List[str] + self.sampler = kwargs.get('sampler', 'ddpm') + self.neg = kwargs.get('neg', '') + self.seed = kwargs.get('seed', -1) + self.walk_seed = kwargs.get('walk_seed', False) + self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.prompt_2 = kwargs.get('prompt_2', None) + self.neg_2 = kwargs.get('neg_2', None) + self.prompts = kwargs.get('prompts', None) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.ext = kwargs.get('ext', 'png') + self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) + if kwargs.get('shuffle', False): + # shuffle the prompts + random.shuffle(self.prompts) + + +class ReferenceGenerator(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.device = self.get_conf('device', 'cuda') + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.generate_config = GenerateConfig(**self.get_conf('generate', required=True)) + self.is_latents_cached = True + raw_datasets = self.get_conf('datasets', None) + if raw_datasets is not None and len(raw_datasets) > 0: + raw_datasets = preprocess_dataset_raw_config(raw_datasets) + self.datasets = None + self.datasets_reg = None + self.dtype = self.get_conf('dtype', 'float16') + 
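# convert the dtype string from the config into a torch dtype for the adapter and pipelines +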
self.torch_dtype = get_torch_dtype(self.dtype) + self.params = [] + if raw_datasets is not None and len(raw_datasets) > 0: + for raw_dataset in raw_datasets: + dataset = DatasetConfig(**raw_dataset) + is_caching = dataset.cache_latents or dataset.cache_latents_to_disk + if not is_caching: + self.is_latents_cached = False + if dataset.is_reg: + if self.datasets_reg is None: + self.datasets_reg = [] + self.datasets_reg.append(dataset) + else: + if self.datasets is None: + self.datasets = [] + self.datasets.append(dataset) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.dtype, + ) + print(f"Using device {self.device}") + self.data_loader: DataLoader = None + self.adapter: T2IAdapter = None + + def run(self): + super().run() + print("Loading model...") + self.sd.load_model() + device = torch.device(self.device) + + if self.generate_config.t2i_adapter_path is not None: + self.adapter = T2IAdapter.from_pretrained( + self.generate_config.t2i_adapter_path, + torch_dtype=self.torch_dtype, + varient="fp16" + ).to(device) + + midas_depth = MidasDetector.from_pretrained( + "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large" + ).to(device) + + if self.model_config.is_xl: + pipe = StableDiffusionXLAdapterPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder[0], + text_encoder_2=self.sd.text_encoder[1], + tokenizer=self.sd.tokenizer[0], + tokenizer_2=self.sd.tokenizer[1], + scheduler=get_sampler(self.generate_config.sampler), + adapter=self.adapter, + ).to(device, dtype=self.torch_dtype) + else: + pipe = StableDiffusionAdapterPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder, + tokenizer=self.sd.tokenizer, + scheduler=get_sampler(self.generate_config.sampler), + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + adapter=self.adapter, + ).to(device, dtype=self.torch_dtype) + pipe.set_progress_bar_config(disable=True) + + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True) + + self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd) + + num_batches = len(self.data_loader) + pbar = tqdm(total=num_batches, desc="Generating images") + seed = self.generate_config.seed + # load images from datasets, use tqdm + for i, batch in enumerate(self.data_loader): + batch: DataLoaderBatchDTO = batch + + file_item: FileItemDTO = batch.file_items[0] + img_path = file_item.path + img_filename = os.path.basename(img_path) + img_filename_no_ext = os.path.splitext(img_filename)[0] + output_path = os.path.join(self.output_folder, img_filename) + output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt') + output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png') + + caption = batch.get_caption_list()[0] + + img: torch.Tensor = batch.tensor.clone() + # image comes in -1 to 1. 
convert to a PIL RGB image + img = (img + 1) / 2 + img = img.clamp(0, 1) + img = img[0].permute(1, 2, 0).cpu().numpy() + img = (img * 255).astype(np.uint8) + image = Image.fromarray(img) + + width, height = image.size + min_res = min(width, height) + + if self.generate_config.walk_seed: + seed = seed + 1 + + if self.generate_config.seed == -1: + # random + seed = random.randint(0, 1000000) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # generate depth map + image = midas_depth( + image, + detect_resolution=min_res, # do 512 ? + image_resolution=min_res + ) + + # image.save(output_depth_path) + + gen_images = pipe( + prompt=caption, + negative_prompt=self.generate_config.neg, + image=image, + num_inference_steps=self.generate_config.sample_steps, + adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale, + guidance_scale=self.generate_config.guidance_scale, + ).images[0] + os.makedirs(os.path.dirname(output_path), exist_ok=True) + gen_images.save(output_path) + + # save caption + with open(output_caption_path, 'w') as f: + f.write(caption) + + pbar.update(1) + batch.cleanup() + + pbar.close() + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/extensions_built_in/advanced_generator/__init__.py b/extensions_built_in/advanced_generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65811655a991c84551d2915c7b04fa66c0d8fbaa --- /dev/null +++ b/extensions_built_in/advanced_generator/__init__.py @@ -0,0 +1,59 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class AdvancedReferenceGeneratorExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "reference_generator" + + # name is the name of the extension for printing + name = "Reference Generator" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ReferenceGenerator import ReferenceGenerator + return ReferenceGenerator + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class PureLoraGenerator(Extension): + # uid must be unique, it is how the extension is identified + uid = "pure_lora_generator" + + # name is the name of the extension for printing + name = "Pure LoRA Generator" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .PureLoraGenerator import PureLoraGenerator + return PureLoraGenerator + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class Img2ImgGeneratorExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "batch_img2img" + + # name is the name of the extension for printing + name = "Img2ImgGeneratorExtension" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .Img2ImgGenerator import Img2ImgGenerator + return Img2ImgGenerator + + 
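# export the generator extensions defined above so the toolkit can pick them up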
+AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension +] diff --git a/extensions_built_in/advanced_generator/config/train.example.yaml b/extensions_built_in/advanced_generator/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..793d5d55b9a282d58b53f0e6fafba0b5aa66f7af --- /dev/null +++ b/extensions_built_in/advanced_generator/config/train.example.yaml @@ -0,0 +1,91 @@ +--- +job: extension +config: + name: test_v1 + process: + - type: 'textual_inversion_trainer' + training_folder: "out/TI" + device: cuda:0 + # for tensorboard logging + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_ext: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + train_unet: false + gradient_checkpointing: true + train_text_encoder: false + optimizer: "adamw" +# optimizer: "prodigy" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 4 + dtype: bf16 + xformers: true + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this + model: + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of [trigger] laughing" + - "photo of [trigger] smiling" + - "[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. 
The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/extensions_built_in/concept_replacer/ConceptReplacer.py b/extensions_built_in/concept_replacer/ConceptReplacer.py new file mode 100644 index 0000000000000000000000000000000000000000..1600e8e1851c402f5468d6c48fdc41ec2d4487fb --- /dev/null +++ b/extensions_built_in/concept_replacer/ConceptReplacer.py @@ -0,0 +1,151 @@ +import random +from collections import OrderedDict +from torch.utils.data import DataLoader +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +import torch +from jobs.process import BaseSDTrainProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class ConceptReplacementConfig: + def __init__(self, **kwargs): + self.concept: str = kwargs.get('concept', '') + self.replacement: str = kwargs.get('replacement', '') + + +class ConceptReplacer(BaseSDTrainProcess): + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + replacement_list = self.config.get('replacements', []) + self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list] + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + + # textual inversion + if self.embedding is not None: + # set text encoder to train. Not sure if this is necessary but diffusers example did it + self.sd.text_encoder.train() + + def hook_train_loop(self, batch): + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + network_weight_list = batch.get_network_weight_list() + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: + network = BlankNetwork() + + batch_replacement_list = [] + # get a random replacement for each prompt + for prompt in conditioned_prompts: + replacement = random.choice(self.replacement_list) + batch_replacement_list.append(replacement) + + # build out prompts + concept_prompts = [] + replacement_prompts = [] + for idx, replacement in enumerate(batch_replacement_list): + prompt = conditioned_prompts[idx] + + # insert shuffled concept at beginning and end of prompt + shuffled_concept = [x.strip() for x in replacement.concept.split(',')] + random.shuffle(shuffled_concept) + shuffled_concept = ', '.join(shuffled_concept) + concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}") + + # insert replacement at beginning and end of prompt + shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')] + random.shuffle(shuffled_replacement) + shuffled_replacement = ', '.join(shuffled_replacement) + replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}") + + # predict the replacement without network + conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype) + + replacement_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, 
dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + + del conditional_embeds + replacement_pred = replacement_pred.detach() + + self.optimizer.zero_grad() + flush() + + # text encoding + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding: + grad_on_text_encoder = True + + # set the weights + network.multiplier = network_weight_list + + # activate network if it exits + with network: + with torch.set_grad_enabled(grad_on_text_encoder): + # embed the prompts + conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype) + if not grad_on_text_encoder: + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + self.optimizer.zero_grad() + flush() + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # back propagate loss to free ram + loss.backward() + flush() + + # apply gradients + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() + + if self.embedding is not None: + # Let's make sure we don't update any embedding weights besides the newly added token + self.embedding.restore_embeddings() + + loss_dict = OrderedDict( + {'loss': loss.item()} + ) + # reset network multiplier + network.multiplier = 1.0 + + return loss_dict diff --git a/extensions_built_in/concept_replacer/__init__.py b/extensions_built_in/concept_replacer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69dc731141158741d7b5165a7e5dc2f77b467051 --- /dev/null +++ b/extensions_built_in/concept_replacer/__init__.py @@ -0,0 +1,26 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. 
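+# Registers the built-in Concept Replacer training process implemented in ConceptReplacer.py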
+from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class ConceptReplacerExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "concept_replacer" + + # name is the name of the extension for printing + name = "Concept Replacer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ConceptReplacer import ConceptReplacer + return ConceptReplacer + + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ConceptReplacerExtension, +] diff --git a/extensions_built_in/concept_replacer/config/train.example.yaml b/extensions_built_in/concept_replacer/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..793d5d55b9a282d58b53f0e6fafba0b5aa66f7af --- /dev/null +++ b/extensions_built_in/concept_replacer/config/train.example.yaml @@ -0,0 +1,91 @@ +--- +job: extension +config: + name: test_v1 + process: + - type: 'textual_inversion_trainer' + training_folder: "out/TI" + device: cuda:0 + # for tensorboard logging + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_ext: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + train_unet: false + gradient_checkpointing: true + train_text_encoder: false + optimizer: "adamw" +# optimizer: "prodigy" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 4 + dtype: bf16 + xformers: true + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this + model: + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of [trigger] laughing" + - "photo of [trigger] smiling" + - "[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. 
The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/extensions_built_in/dataset_tools/DatasetTools.py b/extensions_built_in/dataset_tools/DatasetTools.py new file mode 100644 index 0000000000000000000000000000000000000000..d969b77aa9fa9e0fcce001e2fb5d102d60058f6a --- /dev/null +++ b/extensions_built_in/dataset_tools/DatasetTools.py @@ -0,0 +1,20 @@ +from collections import OrderedDict +import gc +import torch +from jobs.process import BaseExtensionProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class DatasetTools(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + + def run(self): + super().run() + + raise NotImplementedError("This extension is not yet implemented") diff --git a/extensions_built_in/dataset_tools/SuperTagger.py b/extensions_built_in/dataset_tools/SuperTagger.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb3c70e57ff3d2cc6b5dd23400d9a30b72243c7 --- /dev/null +++ b/extensions_built_in/dataset_tools/SuperTagger.py @@ -0,0 +1,196 @@ +import copy +import json +import os +from collections import OrderedDict +import gc +import traceback +import torch +from PIL import Image, ImageOps +from tqdm import tqdm + +from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo +from .tools.fuyu_utils import FuyuImageProcessor +from .tools.image_tools import load_image, ImageProcessor, resize_to_max +from .tools.llava_utils import LLaVAImageProcessor +from .tools.caption import default_long_prompt, default_short_prompt, default_replacements +from jobs.process import BaseExtensionProcess +from .tools.sync_tools import get_img_paths + +img_ext = ['.jpg', '.jpeg', '.png', '.webp'] + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +VERSION = 2 + + +class SuperTagger(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + parent_dir = config.get('parent_dir', None) + self.dataset_paths: list[str] = config.get('dataset_paths', []) + self.device = config.get('device', 'cuda') + self.steps: list[Step] = config.get('steps', []) + self.caption_method = config.get('caption_method', 'llava:default') + self.caption_prompt = config.get('caption_prompt', default_long_prompt) + self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt) + self.force_reprocess_img = config.get('force_reprocess_img', False) + self.caption_replacements = config.get('caption_replacements', default_replacements) + self.caption_short_replacements = config.get('caption_short_replacements', default_replacements) + self.master_dataset_dict = OrderedDict() + self.dataset_master_config_file = config.get('dataset_master_config_file', None) + if parent_dir is not None and len(self.dataset_paths) == 0: + # find all folders in the patent_dataset_path + self.dataset_paths = [ + os.path.join(parent_dir, folder) + for folder in os.listdir(parent_dir) + if os.path.isdir(os.path.join(parent_dir, folder)) + ] + else: + # make sure they exist + for dataset_path in self.dataset_paths: + if not os.path.exists(dataset_path): + raise ValueError(f"Dataset path does not exist: {dataset_path}") + + print(f"Found {len(self.dataset_paths)} dataset paths") + + 
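For reference, these are the keys the SuperTagger constructor above reads via `config.get(...)`; the sketch below is a hypothetical process config, and every value is a placeholder rather than a toolkit default:

```python
# Hypothetical SuperTagger process config covering the keys read in __init__ above;
# all values are placeholders for illustration only.
super_tagger_config = {
    "parent_dir": "/path/to/datasets",        # or list datasets explicitly via dataset_paths
    "dataset_paths": [],
    "device": "cuda",
    "steps": ["caption", "caption_short", "contrast_stretch"],
    "caption_method": "llava:default",        # anything starting with "llava" or "fuyu"
    "force_reprocess_img": False,
    "dataset_master_config_file": "/path/to/master_dataset.json",
}
```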
self.image_processor: ImageProcessor = self.get_image_processor() + + def get_image_processor(self): + if self.caption_method.startswith('llava'): + return LLaVAImageProcessor(device=self.device) + elif self.caption_method.startswith('fuyu'): + return FuyuImageProcessor(device=self.device) + else: + raise ValueError(f"Unknown caption method: {self.caption_method}") + + def process_image(self, img_path: str): + root_img_dir = os.path.dirname(os.path.dirname(img_path)) + filename = os.path.basename(img_path) + filename_no_ext = os.path.splitext(filename)[0] + train_dir = os.path.join(root_img_dir, TRAIN_DIR) + train_img_path = os.path.join(train_dir, filename) + json_path = os.path.join(train_dir, f"{filename_no_ext}.json") + + # check if json exists, if it does load it as image info + if os.path.exists(json_path): + with open(json_path, 'r') as f: + img_info = ImgInfo(**json.load(f)) + else: + img_info = ImgInfo() + + # always send steps first in case other processes need them + img_info.add_steps(copy.deepcopy(self.steps)) + img_info.set_version(VERSION) + img_info.set_caption_method(self.caption_method) + + image: Image = None + caption_image: Image = None + + did_update_image = False + + # trigger reprocess of steps + if self.force_reprocess_img: + img_info.trigger_image_reprocess() + + # set the image as updated if it does not exist on disk + if not os.path.exists(train_img_path): + did_update_image = True + image = load_image(img_path) + if img_info.force_image_process: + did_update_image = True + image = load_image(img_path) + + # go through the needed steps + for step in copy.deepcopy(img_info.state.steps_to_complete): + if step == 'caption': + # load image + if image is None: + image = load_image(img_path) + if caption_image is None: + caption_image = resize_to_max(image, 1024, 1024) + + if not self.image_processor.is_loaded: + print('Loading Model. Takes a while, especially the first time') + self.image_processor.load_model() + + img_info.caption = self.image_processor.generate_caption( + image=caption_image, + prompt=self.caption_prompt, + replacements=self.caption_replacements + ) + img_info.mark_step_complete(step) + elif step == 'caption_short': + # load image + if image is None: + image = load_image(img_path) + + if caption_image is None: + caption_image = resize_to_max(image, 1024, 1024) + + if not self.image_processor.is_loaded: + print('Loading Model. 
Takes a while, especially the first time') + self.image_processor.load_model() + img_info.caption_short = self.image_processor.generate_caption( + image=caption_image, + prompt=self.caption_short_prompt, + replacements=self.caption_short_replacements + ) + img_info.mark_step_complete(step) + elif step == 'contrast_stretch': + # load image + if image is None: + image = load_image(img_path) + image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True) + did_update_image = True + img_info.mark_step_complete(step) + else: + raise ValueError(f"Unknown step: {step}") + + os.makedirs(os.path.dirname(train_img_path), exist_ok=True) + if did_update_image: + image.save(train_img_path) + + if img_info.is_dirty: + with open(json_path, 'w') as f: + json.dump(img_info.to_dict(), f, indent=4) + + if self.dataset_master_config_file: + # add to master dict + self.master_dataset_dict[train_img_path] = img_info.to_dict() + + def run(self): + super().run() + imgs_to_process = [] + # find all images + for dataset_path in self.dataset_paths: + raw_dir = os.path.join(dataset_path, RAW_DIR) + raw_image_paths = get_img_paths(raw_dir) + for raw_image_path in raw_image_paths: + imgs_to_process.append(raw_image_path) + + if len(imgs_to_process) == 0: + print(f"No images to process") + else: + print(f"Found {len(imgs_to_process)} to process") + + for img_path in tqdm(imgs_to_process, desc="Processing images"): + try: + self.process_image(img_path) + except Exception: + # print full stack trace + print(traceback.format_exc()) + continue + # self.process_image(img_path) + + if self.dataset_master_config_file is not None: + # save it as json + with open(self.dataset_master_config_file, 'w') as f: + json.dump(self.master_dataset_dict, f, indent=4) + + del self.image_processor + flush() diff --git a/extensions_built_in/dataset_tools/SyncFromCollection.py b/extensions_built_in/dataset_tools/SyncFromCollection.py new file mode 100644 index 0000000000000000000000000000000000000000..e65a35848e933fdab843d8677b6e5000e1393825 --- /dev/null +++ b/extensions_built_in/dataset_tools/SyncFromCollection.py @@ -0,0 +1,131 @@ +import os +import shutil +from collections import OrderedDict +import gc +from typing import List + +import torch +from tqdm import tqdm + +from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR +from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \ + get_img_paths +from jobs.process import BaseExtensionProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class SyncFromCollection(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + + self.min_width = config.get('min_width', 1024) + self.min_height = config.get('min_height', 1024) + + # add our min_width and min_height to each dataset config if they don't exist + for dataset_config in config.get('dataset_sync', []): + if 'min_width' not in dataset_config: + dataset_config['min_width'] = self.min_width + if 'min_height' not in dataset_config: + dataset_config['min_height'] = self.min_height + + self.dataset_configs: List[DatasetSyncCollectionConfig] = [ + DatasetSyncCollectionConfig(**dataset_config) + for dataset_config in config.get('dataset_sync', []) + ] + print(f"Found {len(self.dataset_configs)} dataset configs") + + def move_new_images(self, root_dir: str): + raw_dir = os.path.join(root_dir, RAW_DIR) + new_dir = os.path.join(root_dir, NEW_DIR) + 
new_images = get_img_paths(new_dir) + + for img_path in new_images: + # move to raw + new_path = os.path.join(raw_dir, os.path.basename(img_path)) + shutil.move(img_path, new_path) + + # remove new dir + shutil.rmtree(new_dir) + + def sync_dataset(self, config: DatasetSyncCollectionConfig): + if config.host == 'unsplash': + get_images = get_unsplash_images + elif config.host == 'pexels': + get_images = get_pexels_images + else: + raise ValueError(f"Unknown host: {config.host}") + + results = { + 'num_downloaded': 0, + 'num_skipped': 0, + 'bad': 0, + 'total': 0, + } + + photos = get_images(config) + raw_dir = os.path.join(config.directory, RAW_DIR) + new_dir = os.path.join(config.directory, NEW_DIR) + raw_images = get_local_image_file_names(raw_dir) + new_images = get_local_image_file_names(new_dir) + + for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"): + try: + if photo.filename not in raw_images and photo.filename not in new_images: + download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height) + results['num_downloaded'] += 1 + else: + results['num_skipped'] += 1 + except Exception as e: + print(f" - BAD({photo.id}): {e}") + results['bad'] += 1 + continue + results['total'] += 1 + + return results + + def print_results(self, results): + print( + f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}") + + def run(self): + super().run() + print(f"Syncing {len(self.dataset_configs)} datasets") + all_results = None + failed_datasets = [] + for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True): + try: + results = self.sync_dataset(dataset_config) + if all_results is None: + all_results = {**results} + else: + for key, value in results.items(): + all_results[key] += value + + self.print_results(results) + except Exception as e: + print(f" - FAILED: {e}") + if 'response' in e.__dict__: + error = f"{e.response.status_code}: {e.response.text}" + print(f" - {error}") + failed_datasets.append({'dataset': dataset_config, 'error': error}) + else: + failed_datasets.append({'dataset': dataset_config, 'error': str(e)}) + continue + + print("Moving new images to raw") + for dataset_config in self.dataset_configs: + self.move_new_images(dataset_config.directory) + + print("Done syncing datasets") + self.print_results(all_results) + + if len(failed_datasets) > 0: + print(f"Failed to sync {len(failed_datasets)} datasets") + for failed in failed_datasets: + print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}") + print(f" - ERR: {failed['error']}") diff --git a/extensions_built_in/dataset_tools/__init__.py b/extensions_built_in/dataset_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b86d3cf5741ddd1892fda4fe24003012d770548b --- /dev/null +++ b/extensions_built_in/dataset_tools/__init__.py @@ -0,0 +1,43 @@ +from toolkit.extension import Extension + + +class DatasetToolsExtension(Extension): + uid = "dataset_tools" + + # name is the name of the extension for printing + name = "Dataset Tools" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .DatasetTools import DatasetTools + return DatasetTools + + +class SyncFromCollectionExtension(Extension): + uid = "sync_from_collection" + name = "Sync from Collection" + + 
@classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .SyncFromCollection import SyncFromCollection + return SyncFromCollection + + +class SuperTaggerExtension(Extension): + uid = "super_tagger" + name = "Super Tagger" + + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .SuperTagger import SuperTagger + return SuperTagger + + +AI_TOOLKIT_EXTENSIONS = [ + SyncFromCollectionExtension, DatasetToolsExtension, SuperTaggerExtension +] diff --git a/extensions_built_in/dataset_tools/tools/caption.py b/extensions_built_in/dataset_tools/tools/caption.py new file mode 100644 index 0000000000000000000000000000000000000000..370786a80a380749863dbfc4449bdb33c09a5daa --- /dev/null +++ b/extensions_built_in/dataset_tools/tools/caption.py @@ -0,0 +1,53 @@ + +caption_manipulation_steps = ['caption', 'caption_short'] + +default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.' +default_short_prompt = 'caption this image in less than ten words' + +default_replacements = [ + ("the image features", ""), + ("the image shows", ""), + ("the image depicts", ""), + ("the image is", ""), + ("in this image", ""), + ("in the image", ""), +] + + +def clean_caption(cap, replacements=None): + if replacements is None: + replacements = default_replacements + + # remove any newlines + cap = cap.replace("\n", ", ") + cap = cap.replace("\r", ", ") + cap = cap.replace(".", ",") + cap = cap.replace("\"", "") + + # remove unicode characters + cap = cap.encode('ascii', 'ignore').decode('ascii') + + # make lowercase + cap = cap.lower() + # remove any extra spaces + cap = " ".join(cap.split()) + + for replacement in replacements: + if replacement[0].startswith('*'): + # we are removing all text if it starts with this and the rest matches + search_text = replacement[0][1:] + if cap.startswith(search_text): + cap = "" + else: + cap = cap.replace(replacement[0].lower(), replacement[1].lower()) + + cap_list = cap.split(",") + # trim whitespace + cap_list = [c.strip() for c in cap_list] + # remove empty strings + cap_list = [c for c in cap_list if c != ""] + # remove duplicates + cap_list = list(dict.fromkeys(cap_list)) + # join back together + cap = ", ".join(cap_list) + return cap \ No newline at end of file diff --git a/extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py b/extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..60c69dbb503c724b2ecaa8fdfd5a702b43cee87c --- /dev/null +++ b/extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py @@ -0,0 +1,187 @@ +import json +from typing import Literal, Type, TYPE_CHECKING + +Host: Type = Literal['unsplash', 'pexels'] + +RAW_DIR = "raw" +NEW_DIR = "_tmp" +TRAIN_DIR = "train" +DEPTH_DIR = "depth" + +from .image_tools import Step, img_manipulation_steps +from .caption import caption_manipulation_steps + + +class DatasetSyncCollectionConfig: + def __init__(self, **kwargs): + self.host: Host = kwargs.get('host', None) + self.collection_id: str 
= kwargs.get('collection_id', None) + self.directory: str = kwargs.get('directory', None) + self.api_key: str = kwargs.get('api_key', None) + self.min_width: int = kwargs.get('min_width', 1024) + self.min_height: int = kwargs.get('min_height', 1024) + + if self.host is None: + raise ValueError("host is required") + if self.collection_id is None: + raise ValueError("collection_id is required") + if self.directory is None: + raise ValueError("directory is required") + if self.api_key is None: + raise ValueError(f"api_key is required: {self.host}:{self.collection_id}") + + +class ImageState: + def __init__(self, **kwargs): + self.steps_complete: list[Step] = kwargs.get('steps_complete', []) + self.steps_to_complete: list[Step] = kwargs.get('steps_to_complete', []) + + def to_dict(self): + return { + 'steps_complete': self.steps_complete + } + + +class Rect: + def __init__(self, **kwargs): + self.x = kwargs.get('x', 0) + self.y = kwargs.get('y', 0) + self.width = kwargs.get('width', 0) + self.height = kwargs.get('height', 0) + + def to_dict(self): + return { + 'x': self.x, + 'y': self.y, + 'width': self.width, + 'height': self.height + } + + +class ImgInfo: + def __init__(self, **kwargs): + self.version: int = kwargs.get('version', None) + self.caption: str = kwargs.get('caption', None) + self.caption_short: str = kwargs.get('caption_short', None) + self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])] + self.state = ImageState(**kwargs.get('state', {})) + self.caption_method = kwargs.get('caption_method', None) + self.other_captions = kwargs.get('other_captions', {}) + self._upgrade_state() + self.force_image_process: bool = False + self._requested_steps: list[Step] = [] + + self.is_dirty: bool = False + + def _upgrade_state(self): + # upgrades older states + if self.caption is not None and 'caption' not in self.state.steps_complete: + self.mark_step_complete('caption') + self.is_dirty = True + if self.caption_short is not None and 'caption_short' not in self.state.steps_complete: + self.mark_step_complete('caption_short') + self.is_dirty = True + if self.caption_method is None and self.caption is not None: + # added caption method in version 2. 
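The sync config class above raises if `host`, `collection_id`, `directory`, or `api_key` is missing. A hypothetical entry, with every value a placeholder:

```python
# Hypothetical dataset_sync entry; DatasetSyncCollectionConfig above validates that host,
# collection_id, directory, and api_key are present. All values below are placeholders.
from extensions_built_in.dataset_tools.tools.dataset_tools_config_modules import DatasetSyncCollectionConfig

entry = DatasetSyncCollectionConfig(
    host="unsplash",                    # 'unsplash' or 'pexels'
    collection_id="1234567",            # placeholder collection id
    directory="/path/to/dataset_root",  # 'raw' and '_tmp' subfolders live under this
    api_key="YOUR_API_KEY",             # placeholder key for the chosen host
    min_width=1024,
    min_height=1024,
)
```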
Was all llava before that + self.caption_method = 'llava:default' + self.is_dirty = True + + def to_dict(self): + return { + 'version': self.version, + 'caption_method': self.caption_method, + 'caption': self.caption, + 'caption_short': self.caption_short, + 'poi': [poi.to_dict() for poi in self.poi], + 'state': self.state.to_dict(), + 'other_captions': self.other_captions + } + + def mark_step_complete(self, step: Step): + if step not in self.state.steps_complete: + self.state.steps_complete.append(step) + if step in self.state.steps_to_complete: + self.state.steps_to_complete.remove(step) + self.is_dirty = True + + def add_step(self, step: Step): + if step not in self.state.steps_to_complete and step not in self.state.steps_complete: + self.state.steps_to_complete.append(step) + + def trigger_image_reprocess(self): + if self._requested_steps is None: + raise Exception("Must call add_steps before trigger_image_reprocess") + steps = self._requested_steps + # remove all image manipulationf from steps_to_complete + for step in img_manipulation_steps: + if step in self.state.steps_to_complete: + self.state.steps_to_complete.remove(step) + if step in self.state.steps_complete: + self.state.steps_complete.remove(step) + self.force_image_process = True + self.is_dirty = True + # we want to keep the order passed in process file + for step in steps: + if step in img_manipulation_steps: + self.add_step(step) + + def add_steps(self, steps: list[Step]): + self._requested_steps = [step for step in steps] + for stage in steps: + self.add_step(stage) + + # update steps if we have any img processes not complete, we have to reprocess them all + # if any steps_to_complete are in img_manipulation_steps + + is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete]) + order_has_changed = False + + if not is_manipulating_image: + # check to see if order has changed. No need to if already redoing it. 
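`SuperTagger.process_image()` persists this object as a JSON sidecar next to each training image, using `to_dict()` above. An illustrative example of the resulting sidecar (caption text is made up):

```python
# Illustrative shape of the per-image JSON sidecar written by SuperTagger.process_image(),
# mirroring ImgInfo.to_dict() above; the caption strings are invented examples.
example_sidecar = {
    "version": 2,
    "caption_method": "llava:default",
    "caption": "a man with short brown hair, smiling, standing outdoors",
    "caption_short": "a smiling man outdoors",
    "poi": [],                                               # list of {x, y, width, height} rects
    "state": {"steps_complete": ["caption", "caption_short"]},
    "other_captions": {},
}
```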
Will detect if ones are removed + target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps] + current_img_manipulation_order = [step for step in self.state.steps_complete if + step in img_manipulation_steps] + if target_img_manipulation_order != current_img_manipulation_order: + order_has_changed = True + + if is_manipulating_image or order_has_changed: + self.trigger_image_reprocess() + + def set_caption_method(self, method: str): + if self._requested_steps is None: + raise Exception("Must call add_steps before set_caption_method") + if self.caption_method != method: + self.is_dirty = True + # move previous caption method to other_captions + if self.caption_method is not None and self.caption is not None or self.caption_short is not None: + self.other_captions[self.caption_method] = { + 'caption': self.caption, + 'caption_short': self.caption_short, + } + self.caption_method = method + self.caption = None + self.caption_short = None + # see if we have a caption from the new method + if method in self.other_captions: + self.caption = self.other_captions[method].get('caption', None) + self.caption_short = self.other_captions[method].get('caption_short', None) + else: + self.trigger_new_caption() + + def trigger_new_caption(self): + self.caption = None + self.caption_short = None + self.is_dirty = True + # check to see if we have any steps in the complete list and move them to the to_complete list + for step in self.state.steps_complete: + if step in caption_manipulation_steps: + self.state.steps_complete.remove(step) + self.state.steps_to_complete.append(step) + + def to_json(self): + return json.dumps(self.to_dict()) + + def set_version(self, version: int): + if self.version != version: + self.is_dirty = True + self.version = version diff --git a/extensions_built_in/dataset_tools/tools/fuyu_utils.py b/extensions_built_in/dataset_tools/tools/fuyu_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..407da10c257b46bbd7c2ce70d4beebff0a7d5a89 --- /dev/null +++ b/extensions_built_in/dataset_tools/tools/fuyu_utils.py @@ -0,0 +1,66 @@ +from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer + +from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption +import torch +from PIL import Image + + +class FuyuImageProcessor: + def __init__(self, device='cuda'): + from transformers import FuyuProcessor, FuyuForCausalLM + self.device = device + self.model: FuyuForCausalLM = None + self.processor: FuyuProcessor = None + self.dtype = torch.bfloat16 + self.tokenizer: AutoTokenizer + self.is_loaded = False + + def load_model(self): + from transformers import FuyuProcessor, FuyuForCausalLM + model_path = "adept/fuyu-8b" + kwargs = {"device_map": self.device} + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=self.dtype, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + self.processor = FuyuProcessor.from_pretrained(model_path) + self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + self.is_loaded = True + + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs) + self.processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=self.tokenizer) + + def generate_caption( + self, image: Image, + prompt: str = default_long_prompt, + 
replacements=default_replacements, + max_new_tokens=512 + ): + # prepare inputs for the model + # text_prompt = f"{prompt}\n" + + # image = image.convert('RGB') + model_inputs = self.processor(text=prompt, images=[image]) + model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for k, v in + model_inputs.items()} + + generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens) + prompt_len = model_inputs["input_ids"].shape[-1] + output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True) + output = clean_caption(output, replacements=replacements) + return output + + # inputs = self.processor(text=text_prompt, images=image, return_tensors="pt") + # for k, v in inputs.items(): + # inputs[k] = v.to(self.device) + + # # autoregressively generate text + # generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens) + # generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True) + # output = generation_text[0] + # + # return clean_caption(output, replacements=replacements) diff --git a/extensions_built_in/dataset_tools/tools/image_tools.py b/extensions_built_in/dataset_tools/tools/image_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d36073c0164047bdced3111a7230084e0b0bd187 --- /dev/null +++ b/extensions_built_in/dataset_tools/tools/image_tools.py @@ -0,0 +1,49 @@ +from typing import Literal, Type, TYPE_CHECKING, Union + +import cv2 +import numpy as np +from PIL import Image, ImageOps + +Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretch'] + +img_manipulation_steps = ['contrast_stretch'] + +img_ext = ['.jpg', '.jpeg', '.png', '.webp'] + +if TYPE_CHECKING: + from .llava_utils import LLaVAImageProcessor + from .fuyu_utils import FuyuImageProcessor + +ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor'] + + +def pil_to_cv2(image): + """Convert a PIL image to a cv2 image.""" + return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + + +def cv2_to_pil(image): + """Convert a cv2 image to a PIL image.""" + return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + +def load_image(img_path: str): + image = Image.open(img_path).convert('RGB') + try: + # transpose with exif data + image = ImageOps.exif_transpose(image) + except Exception as e: + pass + return image + + +def resize_to_max(image, max_width=1024, max_height=1024): + width, height = image.size + if width <= max_width and height <= max_height: + return image + + scale = min(max_width / width, max_height / height) + width = int(width * scale) + height = int(height * scale) + + return image.resize((width, height), Image.LANCZOS) diff --git a/extensions_built_in/dataset_tools/tools/llava_utils.py b/extensions_built_in/dataset_tools/tools/llava_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba38d6613c2ef5a8cd120d8906e5b8d236240bd --- /dev/null +++ b/extensions_built_in/dataset_tools/tools/llava_utils.py @@ -0,0 +1,85 @@ + +from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption + +import torch +from PIL import Image, ImageOps + +from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor + +img_ext = ['.jpg', '.jpeg', '.png', '.webp'] + + +class LLaVAImageProcessor: + def __init__(self, device='cuda'): + try: + from llava.model import LlavaLlamaForCausalLM + except ImportError: + # print("You 
need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git") + print( + "You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git") + raise + self.device = device + self.model: LlavaLlamaForCausalLM = None + self.tokenizer: AutoTokenizer = None + self.image_processor: CLIPImageProcessor = None + self.is_loaded = False + + def load_model(self): + from llava.model import LlavaLlamaForCausalLM + + model_path = "4bit/llava-v1.5-13b-3GB" + # kwargs = {"device_map": "auto"} + kwargs = {"device_map": self.device} + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + vision_tower = self.model.get_vision_tower() + if not vision_tower.is_loaded: + vision_tower.load_model() + vision_tower.to(device=self.device) + self.image_processor = vision_tower.image_processor + self.is_loaded = True + + def generate_caption( + self, image: + Image, prompt: str = default_long_prompt, + replacements=default_replacements, + max_new_tokens=512 + ): + from llava.conversation import conv_templates, SeparatorStyle + from llava.utils import disable_torch_init + from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria + # question = "how many dogs are in the picture?" + disable_torch_init() + conv_mode = "llava_v0" + conv = conv_templates[conv_mode].copy() + roles = conv.roles + image_tensor = self.image_processor.preprocess([image], return_tensors='pt')['pixel_values'].half().cuda() + + inp = f"{roles[0]}: {prompt}" + inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + raw_prompt = conv.get_prompt() + input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, + return_tensors='pt').unsqueeze(0).cuda() + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids, images=image_tensor, do_sample=True, temperature=0.1, + max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria], + top_p=0.8 + ) + outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() + conv.messages[-1][-1] = outputs + output = outputs.rsplit('', 1)[0] + return clean_caption(output, replacements=replacements) diff --git a/extensions_built_in/dataset_tools/tools/sync_tools.py b/extensions_built_in/dataset_tools/tools/sync_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..143cc6bb93a9269eeafffe737f96302e2ca40787 --- /dev/null +++ b/extensions_built_in/dataset_tools/tools/sync_tools.py @@ -0,0 +1,279 @@ +import os +import requests +import tqdm +from typing import List, Optional, TYPE_CHECKING + + +def img_root_path(img_id: str): + return os.path.dirname(os.path.dirname(img_id)) + + +if TYPE_CHECKING: + from .dataset_tools_config_modules import DatasetSyncCollectionConfig + +img_exts = ['.jpg', 
'.jpeg', '.webp', '.png'] + +class Photo: + def __init__( + self, + id, + host, + width, + height, + url, + filename + ): + self.id = str(id) + self.host = host + self.width = width + self.height = height + self.url = url + self.filename = filename + + +def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int): + if img_width > img_height: + scale = min_height / img_height + else: + scale = min_width / img_width + + new_width = int(img_width * scale) + new_height = int(img_height * scale) + + return new_width, new_height + + +def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]: + all_images = [] + next_page = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos" + + while True: + response = requests.get(next_page, headers={ + "Authorization": f"{config.api_key}" + }) + response.raise_for_status() + data = response.json() + all_images.extend(data['media']) + if 'next_page' in data and data['next_page']: + next_page = data['next_page'] + else: + break + + photos = [] + for image in all_images: + new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height) + url = f"{image['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}" + filename = os.path.basename(image['src']['original']) + + photos.append(Photo( + id=image['id'], + host="pexels", + width=image['width'], + height=image['height'], + url=url, + filename=filename + )) + + return photos + + +def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]: + headers = { + # "Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}" + "Authorization": f"Client-ID {config.api_key}" + } + # headers['Authorization'] = f"Bearer {token}" + + url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30" + response = requests.get(url, headers=headers) + response.raise_for_status() + res_headers = response.headers + # parse the link header to get the next page + # 'Link': '; rel="last", ; rel="next"' + has_next_page = False + if 'Link' in res_headers: + has_next_page = True + link_header = res_headers['Link'] + link_header = link_header.split(',') + link_header = [link.strip() for link in link_header] + link_header = [link.split(';') for link in link_header] + link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header] + link_header = {link[1]: link[0] for link in link_header} + + # get page number from last url + last_page = link_header['rel="last'] + last_page = last_page.split('?')[1] + last_page = last_page.split('&') + last_page = [param.split('=') for param in last_page] + last_page = {param[0]: param[1] for param in last_page} + last_page = int(last_page['page']) + + all_images = response.json() + + if has_next_page: + # assume we start on page 1, so we don't need to get it again + for page in tqdm.tqdm(range(2, last_page + 1)): + url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30" + response = requests.get(url, headers=headers) + response.raise_for_status() + all_images.extend(response.json()) + + photos = [] + for image in all_images: + new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height) + url = f"{image['urls']['raw']}&w={new_width}" + filename = f"{image['id']}.jpg" + + photos.append(Photo( + id=image['id'], + host="unsplash", + width=image['width'], + height=image['height'], + url=url, + 
filename=filename + )) + + return photos + + +def get_img_paths(dir_path: str): + os.makedirs(dir_path, exist_ok=True) + local_files = os.listdir(dir_path) + # remove non image files + local_files = [file for file in local_files if os.path.splitext(file)[1].lower() in img_exts] + # make full path + local_files = [os.path.join(dir_path, file) for file in local_files] + return local_files + + +def get_local_image_ids(dir_path: str): + os.makedirs(dir_path, exist_ok=True) + local_files = get_img_paths(dir_path) + # assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg' + return set([os.path.basename(file).split('.')[0] for file in local_files]) + + +def get_local_image_file_names(dir_path: str): + os.makedirs(dir_path, exist_ok=True) + local_files = get_img_paths(dir_path) + # assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg' + return set([os.path.basename(file) for file in local_files]) + + +def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024): + img_width = photo.width + img_height = photo.height + + if img_width < min_width or img_height < min_height: + raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}") + + img_response = requests.get(photo.url) + img_response.raise_for_status() + os.makedirs(dir_path, exist_ok=True) + + filename = os.path.join(dir_path, photo.filename) + with open(filename, 'wb') as file: + file.write(img_response.content) + + +def update_caption(img_path: str): + # if the caption is a txt file, convert it to a json file + filename_no_ext = os.path.splitext(os.path.basename(img_path))[0] + # see if it exists + if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json")): + # todo add poi and what not + return # we have a json file + caption = "" + # see if txt file exists + if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")): + # read it + with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"), 'r') as file: + caption = file.read() + # write json file + with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json"), 'w') as file: + file.write(f'{{"caption": "{caption}"}}') + + # delete txt file + os.remove(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")) + + +# def equalize_img(img_path: str): +# input_path = img_path +# output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path)) +# os.makedirs(os.path.dirname(output_path), exist_ok=True) +# process_img( +# img_path=input_path, +# output_path=output_path, +# equalize=True, +# max_size=2056, +# white_balance=False, +# gamma_correction=False, +# strength=0.6, +# ) + + +# def annotate_depth(img_path: str): +# # make fake args +# args = argparse.Namespace() +# args.annotator = "midas" +# args.res = 1024 +# +# img = cv2.imread(img_path) +# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) +# +# output = annotate(img, args) +# +# output = output.astype('uint8') +# output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) +# +# os.makedirs(os.path.dirname(img_path), exist_ok=True) +# output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path)) +# +# cv2.imwrite(output_path, output) + + +# def invert_depth(img_path: str): +# img = cv2.imread(img_path) +# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) +# # invert the colors +# img = cv2.bitwise_not(img) +# +# os.makedirs(os.path.dirname(img_path), exist_ok=True) +# output_path = 
os.path.join(img_root_path(img_path), INVERTED_DEPTH_DIR, os.path.basename(img_path)) +# cv2.imwrite(output_path, img) + + + # + # # update our list of raw images + # raw_images = get_img_paths(raw_dir) + # + # # update raw captions + # for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"): + # update_caption(image_id) + # + # # equalize images + # for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"): + # if img_path not in eq_images: + # equalize_img(img_path) + # + # # update our list of eq images + # eq_images = get_img_paths(eq_dir) + # # update eq captions + # for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"): + # update_caption(image_id) + # + # # annotate depth + # depth_dir = os.path.join(root_dir, DEPTH_DIR) + # depth_images = get_img_paths(depth_dir) + # for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"): + # if img_path not in depth_images: + # annotate_depth(img_path) + # + # depth_images = get_img_paths(depth_dir) + # + # # invert depth + # inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR) + # inv_depth_images = get_img_paths(inv_depth_dir) + # for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"): + # if img_path not in inv_depth_images: + # invert_depth(img_path) diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..11fa8a9eef13c5546bc2de5723301b4d4838a0cc --- /dev/null +++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py @@ -0,0 +1,235 @@ +import copy +import random +from collections import OrderedDict +import os +from contextlib import nullcontext +from typing import Optional, Union, List +from torch.utils.data import ConcatDataset, DataLoader + +from toolkit.config_modules import ReferenceDatasetConfig +from toolkit.data_loader import PairedImageDataset +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +from toolkit import train_tools +import torch +from jobs.process import BaseSDTrainProcess +import random +from toolkit.basic import value_map + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class ReferenceSliderConfig: + def __init__(self, **kwargs): + self.additional_losses: List[str] = kwargs.get('additional_losses', []) + self.weight_jitter: float = kwargs.get('weight_jitter', 0.0) + self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])] + + +class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): + sd: StableDiffusion + data_loader: DataLoader = None + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + self.prompt_txt_list = None + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.slider_config = ReferenceSliderConfig(**self.get_conf('slider', {})) + + def load_datasets(self): + if self.data_loader is None: + print(f"Loading datasets") + datasets = [] + for dataset in self.slider_config.datasets: + print(f" - Dataset: {dataset.pair_folder}") + config = { + 'path': dataset.pair_folder, + 'size': dataset.size, + 
'default_prompt': dataset.target_class, + 'network_weight': dataset.network_weight, + 'pos_weight': dataset.pos_weight, + 'neg_weight': dataset.neg_weight, + 'pos_folder': dataset.pos_folder, + 'neg_folder': dataset.neg_folder, + } + image_dataset = PairedImageDataset(config) + datasets.append(image_dataset) + + concatenated_dataset = ConcatDataset(datasets) + self.data_loader = DataLoader( + concatenated_dataset, + batch_size=self.train_config.batch_size, + shuffle=True, + num_workers=2 + ) + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + self.load_datasets() + + pass + + def hook_train_loop(self, batch): + with torch.no_grad(): + imgs, prompts, network_weights = batch + network_pos_weight, network_neg_weight = network_weights + + if isinstance(network_pos_weight, torch.Tensor): + network_pos_weight = network_pos_weight.item() + if isinstance(network_neg_weight, torch.Tensor): + network_neg_weight = network_neg_weight.item() + + # get an array of random floats between -weight_jitter and weight_jitter + loss_jitter_multiplier = 1.0 + weight_jitter = self.slider_config.weight_jitter + if weight_jitter > 0.0: + jitter_list = random.uniform(-weight_jitter, weight_jitter) + orig_network_pos_weight = network_pos_weight + network_pos_weight += jitter_list + network_neg_weight += (jitter_list * -1.0) + # penalize the loss for its distance from network_pos_weight + # a jitter_list of abs(3.0) on a weight of 5.0 is a 60% jitter + # so the loss_jitter_multiplier needs to be 0.4 + loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0) + + + # if items in network_weight list are tensors, convert them to floats + + dtype = get_torch_dtype(self.train_config.dtype) + imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype) + # split batched images in half so left is negative and right is positive + negative_images, positive_images = torch.chunk(imgs, 2, dim=3) + + positive_latents = self.sd.encode_images(positive_images) + negative_latents = self.sd.encode_images(negative_images) + + height = positive_images.shape[2] + width = positive_images.shape[3] + batch_size = positive_images.shape[0] + + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + timesteps = timesteps.long() + + # get noise + noise_positive = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + noise_negative = noise_positive.clone() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps) + noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps) + + noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0) + noise = torch.cat([noise_positive, noise_negative], dim=0) + timesteps = torch.cat([timesteps, timesteps], dim=0) + network_multiplier = [network_pos_weight * 1.0, 
network_neg_weight * -1.0] + + self.optimizer.zero_grad() + noisy_latents.requires_grad = False + + # if training text encoder enable grads, else do context of no grad + with torch.set_grad_enabled(self.train_config.train_text_encoder): + # fix issue with them being tuples sometimes + prompt_list = [] + for prompt in prompts: + if isinstance(prompt, tuple): + prompt = prompt[0] + prompt_list.append(prompt) + conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype) + conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + + # if self.model_config.is_xl: + # # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop + # network_multiplier_list = network_multiplier + # noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0) + # noise_list = torch.chunk(noise, 2, dim=0) + # timesteps_list = torch.chunk(timesteps, 2, dim=0) + # conditional_embeds_list = split_prompt_embeds(conditional_embeds) + # else: + network_multiplier_list = [network_multiplier] + noisy_latent_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditional_embeds_list = [conditional_embeds] + + losses = [] + # allow to chunk it out to save vram + for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip( + network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list + ): + with self.network: + assert self.network.is_active + + self.network.multiplier = network_multiplier + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + ) + noise = noise.to(self.device_torch, dtype=dtype) + + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() * loss_jitter_multiplier + + loss_float = loss.item() + losses.append(loss_float) + + # back propagate loss to free ram + loss.backward() + + # apply gradients + optimizer.step() + lr_scheduler.step() + + # reset network + self.network.multiplier = 1.0 + + loss_dict = OrderedDict( + {'loss': sum(losses) / len(losses) if len(losses) > 0 else 0.0} + ) + + return loss_dict + # end hook_train_loop diff --git a/extensions_built_in/image_reference_slider_trainer/__init__.py b/extensions_built_in/image_reference_slider_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a15f646bde32a68d194838c4c293619caa8bf93 --- /dev/null +++ b/extensions_built_in/image_reference_slider_trainer/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. 
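The slider trainer above scales its loss by `loss_jitter_multiplier`, computed with `value_map`. A minimal sketch of a linear range remap like `toolkit.basic.value_map` (an assumption about its behaviour, shown only to clarify that calculation):

```python
# Minimal sketch of a linear range remap like toolkit.basic.value_map; this is an assumption
# about its behaviour, included only to clarify the loss_jitter_multiplier math above.
def value_map(x: float, in_min: float, in_max: float, out_min: float, out_max: float) -> float:
    return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min

# a jitter of 0 keeps the full loss; a jitter equal to weight_jitter zeroes it out
print(value_map(0.0, 0.0, 0.5, 1.0, 0.0))   # 1.0
print(value_map(0.5, 0.0, 0.5, 1.0, 0.0))   # 0.0
```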
+from toolkit.extension import Extension + + +# We make a subclass of Extension +class ImageReferenceSliderTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "image_reference_slider_trainer" + + # name is the name of the extension for printing + name = "Image Reference Slider Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ImageReferenceSliderTrainerProcess import ImageReferenceSliderTrainerProcess + return ImageReferenceSliderTrainerProcess + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ImageReferenceSliderTrainer +] diff --git a/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b0f4734ae09fb7e942e33089014ffe59cfd7720 --- /dev/null +++ b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml @@ -0,0 +1,107 @@ +--- +job: extension +config: + name: example_name + process: + - type: 'image_reference_slider_trainer' + training_folder: "/mnt/Train/out/LoRA" + device: cuda:0 + # for tensorboard logging + log_dir: "/home/jaret/Dev/.tensorboard" + network: + type: "lora" + linear: 8 + linear_alpha: 8 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 5000 + lr: 1e-4 + train_unet: true + gradient_checkpointing: true + train_text_encoder: true + optimizer: "adamw" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 1 + dtype: bf16 + xformers: true + skip_first_sample: true + noise_offset: 0.0 + model: + name_or_path: "/path/to/model.safetensors" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + save: + dtype: float16 # precision to save + save_every: 1000 # save every this many steps + max_step_saves_to_keep: 2 # only affects step counts + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of a woman with red hair taking a selfie --m -3" + - "photo of a woman with red hair taking a selfie --m -1" + - "photo of a woman with red hair taking a selfie --m 1" + - "photo of a woman with red hair taking a selfie --m 3" + - "close up photo of a man smiling at the camera, in a tank top --m -3" + - "close up photo of a man smiling at the camera, in a tank top--m -1" + - "close up photo of a man smiling at the camera, in a tank top --m 1" + - "close up photo of a man smiling at the camera, in a tank top --m 3" + - "photo of a blonde woman smiling, barista --m -3" + - "photo of a blonde woman smiling, barista --m -1" + - "photo of a blonde woman smiling, barista --m 1" + - "photo of a blonde woman smiling, barista --m 3" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m 1" + - "photo of a Christina Hendricks --m 3" + - "photo of a Christina Ricci --m -3" + - "photo of a Christina Ricci --m -1" + - "photo of a Christina Ricci --m 1" + - "photo of a Christina Ricci --m 3" + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: 42 + walk_seed: false + guidance_scale: 7 + 
sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + + slider: + datasets: + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 2.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 4.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. \ No newline at end of file diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4d051d9d2b74148cf3bc0e5dc1e3fd5c03f5e3 --- /dev/null +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -0,0 +1,1679 @@ +import os +import random +from collections import OrderedDict +from typing import Union, Literal, List, Optional + +import numpy as np +from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel + +import torch.functional as F +from safetensors.torch import load_file +from torch.utils.data import DataLoader, ConcatDataset + +from toolkit import train_tools +from toolkit.basic import value_map, adain, get_mean_std +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.config_modules import GuidanceConfig +from toolkit.data_loader import get_dataloader_datasets +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO +from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType +from toolkit.image_utils import show_tensors, show_latents +from toolkit.ip_adapter import IPAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork +from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \ + apply_learnable_snr_gos, LearnableSNRGamma +import gc +import torch +from jobs.process import BaseSDTrainProcess +from torchvision import transforms +from diffusers import EMAModel +import math +from toolkit.train_tools import precondition_model_outputs_flow_match + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +adapter_transforms = transforms.Compose([ + transforms.ToTensor(), +]) + + +class SDTrainer(BaseSDTrainProcess): + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] + self.do_prior_prediction = False + self.do_long_prompts = False + self.do_guided_loss = False + self.taesd: Optional[AutoencoderTiny] = None + + self._clip_image_embeds_unconditional: Union[List[str], None] = None + self.negative_prompt_pool: Union[List[str], None] = None 
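The sample prompts in the slider config above append a `--m <value>` suffix to vary the network multiplier per prompt. A hedged sketch of splitting that flag off (an assumption for illustration, not the toolkit's actual prompt parser):

```python
# Hedged sketch of splitting the "--m <value>" suffix used in the sample prompts above into
# (prompt, multiplier); an assumption for illustration, not the toolkit's actual parser.
def split_multiplier(prompt: str, default: float = 1.0):
    if "--m" in prompt:
        text, _, value = prompt.rpartition("--m")
        return text.strip(), float(value)
    return prompt, default

print(split_multiplier("photo of a woman with red hair taking a selfie --m -3"))
# ('photo of a woman with red hair taking a selfie', -3.0)
```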
+ self.batch_negative_prompt: Union[List[str], None] = None + + self.scaler = torch.cuda.amp.GradScaler() + + self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" + + self.do_grad_scale = True + if self.is_fine_tuning and self.is_bfloat: + self.do_grad_scale = False + if self.adapter_config is not None: + if self.adapter_config.train: + self.do_grad_scale = False + + if self.train_config.dtype in ["fp16", "float16"]: + # patch the scaler to allow fp16 training + org_unscale_grads = self.scaler._unscale_grads_ + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + self.scaler._unscale_grads_ = _unscale_grads_replacer + + self.cached_blank_embeds: Optional[PromptEmbeds] = None + self.cached_trigger_embeds: Optional[PromptEmbeds] = None + + + def before_model_load(self): + pass + + def before_dataset_load(self): + self.assistant_adapter = None + # get adapter assistant if one is set + if self.train_config.adapter_assist_name_or_path is not None: + adapter_path = self.train_config.adapter_assist_name_or_path + + if self.train_config.adapter_assist_type == "t2i": + # dont name this adapter since we are not training it + self.assistant_adapter = T2IAdapter.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) + ).to(self.device_torch) + elif self.train_config.adapter_assist_type == "control_net": + self.assistant_adapter = ControlNetModel.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) + ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + else: + raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}") + + self.assistant_adapter.eval() + self.assistant_adapter.requires_grad_(False) + flush() + if self.train_config.train_turbo and self.train_config.show_turbo_outputs: + if self.model_config.is_xl: + self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", + torch_dtype=get_torch_dtype(self.train_config.dtype)) + else: + self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd", + torch_dtype=get_torch_dtype(self.train_config.dtype)) + self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch) + self.taesd.eval() + self.taesd.requires_grad_(False) + + def hook_before_train_loop(self): + super().hook_before_train_loop() + + if self.train_config.do_prior_divergence: + self.do_prior_prediction = True + # move vae to device if we did not cache latents + if not self.is_latents_cached: + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + else: + # offload it. 
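The fp16 branch above patches `GradScaler._unscale_grads_` so it will also unscale fp16 gradients. For context, the standard AMP pattern such a scaler is used with looks roughly like this (a generic PyTorch sketch, not the toolkit's exact training step):

```python
# Generic PyTorch AMP pattern that the GradScaler configured above participates in;
# a sketch for context only, not the toolkit's exact training step.
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)          # unscale grads, skip the step if inf/nan was found
scaler.update()                 # adjust the scale factor for the next iteration
optimizer.zero_grad()
```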
Already cached + self.sd.vae.to('cpu') + flush() + add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch) + if self.adapter is not None: + self.adapter.to(self.device_torch) + + # check if we have regs and using adapter and caching clip embeddings + has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0 + is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))]) + + if has_reg and is_caching_clip_embeddings: + # we need a list of unconditional clip image embeds from other datasets to handle regs + unconditional_clip_image_embeds = [] + datasets = get_dataloader_datasets(self.data_loader) + for i in range(len(datasets)): + unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache + + if len(unconditional_clip_image_embeds) == 0: + raise ValueError("No unconditional clip image embeds found. This should not happen") + + self._clip_image_embeds_unconditional = unconditional_clip_image_embeds + + if self.train_config.negative_prompt is not None: + if os.path.exists(self.train_config.negative_prompt): + with open(self.train_config.negative_prompt, 'r') as f: + self.negative_prompt_pool = f.readlines() + # remove empty + self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""] + else: + # single prompt + self.negative_prompt_pool = [self.train_config.negative_prompt] + + # handle unload text encoder + if self.train_config.unload_text_encoder: + with torch.no_grad(): + if self.train_config.train_text_encoder: + raise ValueError("Cannot unload text encoder if training text encoder") + # cache embeddings + + print("\n***** UNLOADING TEXT ENCODER *****") + print("This will train only with a blank prompt or trigger word, if set") + print("If this is not what you want, remove the unload_text_encoder flag") + print("***********************************") + print("") + self.sd.text_encoder_to(self.device_torch) + self.cached_blank_embeds = self.sd.encode_prompt("") + if self.trigger_word is not None: + self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word) + + # move back to cpu + self.sd.text_encoder_to('cpu') + flush() + + + def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch): + # to process turbo learning, we make one big step from our current timestep to the end + # we then denoise the prediction on that remaining step and target our loss to our target latents + # this currently only works on euler_a (that I know of). Would work on others, but needs to be coded to do so. 
+ # needs to be done on each item in batch as they may all have different timesteps + batch_size = pred.shape[0] + pred_chunks = torch.chunk(pred, batch_size, dim=0) + noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0) + timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0) + latent_chunks = torch.chunk(batch.latents, batch_size, dim=0) + noise_chunks = torch.chunk(noise, batch_size, dim=0) + + with torch.no_grad(): + # set the timesteps to 1000 so we can capture them to calculate the sigmas + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.config.num_train_timesteps, + device=self.device_torch + ) + train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach() + + train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach() + + # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step + self.sd.noise_scheduler.set_timesteps( + 1, + device=self.device_torch + ) + + denoised_pred_chunks = [] + target_pred_chunks = [] + + for i in range(batch_size): + pred_item = pred_chunks[i] + noisy_latents_item = noisy_latents_chunks[i] + timesteps_item = timesteps_chunks[i] + latents_item = latent_chunks[i] + noise_item = noise_chunks[i] + with torch.no_grad(): + timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0] + single_step_timestep_schedule = [timesteps_item.squeeze().item()] + # extract the sigma idx for our midpoint timestep + sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch) + + end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1) + end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch) + + # add noise to our target + + # build the big sigma step. The to step will now be to 0 giving it a full remaining denoising half step + # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach() + self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach() + # set our single timstep + self.sd.noise_scheduler.timesteps = torch.from_numpy( + np.array(single_step_timestep_schedule, dtype=np.float32) + ).to(device=self.device_torch) + + # set the step index to None so it will be recalculated on first step + self.sd.noise_scheduler._step_index = None + + denoised_latent = self.sd.noise_scheduler.step( + pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False + )[0] + + residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype( + self.train_config.dtype)) + # remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically) + denoised_latent = denoised_latent - residual_noise + + denoised_pred_chunks.append(denoised_latent) + + denoised_latents = torch.cat(denoised_pred_chunks, dim=0) + # set the scheduler back to the original timesteps + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.config.num_train_timesteps, + device=self.device_torch + ) + + output = denoised_latents / self.sd.vae.config['scaling_factor'] + output = self.sd.vae.decode(output).sample + + if self.train_config.show_turbo_outputs: + # since we are completely denoising, we can show them here + with torch.no_grad(): + show_tensors(output) + + # we return our big partial step denoised latents as our pred and our untouched latents as our target. + # you can do mse against the two here or run the denoised through the vae for pixel space loss against the + # input tensor images. 
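# --- Illustrative sketch (not part of the patch above) ---
# Numeric view of the "one big step" trick implemented above, assuming the plain
# (non-ancestral) Euler convention x_t = x_0 + sigma_t * eps and an epsilon-prediction
# model, so the step derivative is just the predicted noise. The scheduler.step() call
# above is the real implementation; this only shows the arithmetic.
import torch

def one_big_euler_step(noisy: torch.Tensor, eps_pred: torch.Tensor, noise: torch.Tensor,
                       sigma_t: float, sigma_end: float) -> torch.Tensor:
    x_end = noisy + (sigma_end - sigma_t) * eps_pred  # single Euler step from sigma_t to sigma_end
    return x_end - sigma_end * noise                  # strip the residual noise left at sigma_end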
+ + return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + + # you can expand these in a child class to make customization easier + def calculate_loss( + self, + noise_pred: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + batch: 'DataLoaderBatchDTO', + mask_multiplier: Union[torch.Tensor, float] = 1.0, + prior_pred: Union[torch.Tensor, None] = None, + **kwargs + ): + loss_target = self.train_config.loss_target + is_reg = any(batch.get_is_reg_list()) + + prior_mask_multiplier = None + target_mask_multiplier = None + dtype = get_torch_dtype(self.train_config.dtype) + + has_mask = batch.mask_tensor is not None + + with torch.no_grad(): + loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32) + + if self.train_config.match_noise_norm: + # match the norm of the noise + noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred = noise_pred * (noise_norm / noise_pred_norm) + + if self.train_config.pred_scaler != 1.0: + noise_pred = noise_pred * self.train_config.pred_scaler + + target = None + + if self.train_config.target_noise_multiplier != 1.0: + noise = noise * self.train_config.target_noise_multiplier + + if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask): + if self.train_config.correct_pred_norm and not is_reg: + with torch.no_grad(): + # this only works if doing a prior pred + if prior_pred is not None: + prior_mean = prior_pred.mean([2,3], keepdim=True) + prior_std = prior_pred.std([2,3], keepdim=True) + noise_mean = noise_pred.mean([2,3], keepdim=True) + noise_std = noise_pred.std([2,3], keepdim=True) + + mean_adjust = prior_mean - noise_mean + std_adjust = prior_std - noise_std + + mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier + std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier + + target_mean = noise_mean + mean_adjust + target_std = noise_std + std_adjust + + eps = 1e-5 + # match the noise to the prior + noise = (noise - noise_mean) / (noise_std + eps) + noise = noise * (target_std + eps) + target_mean + noise = noise.detach() + + if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: + assert not self.train_config.train_turbo + with torch.no_grad(): + # we need to make the noise prediction be a masked blending of noise and prior_pred + stretched_mask_multiplier = value_map( + mask_multiplier, + batch.file_items[0].dataset_config.mask_min_value, + 1.0, + 0.0, + 1.0 + ) + + prior_mask_multiplier = 1.0 - stretched_mask_multiplier + + + # target_mask_multiplier = mask_multiplier + # mask_multiplier = 1.0 + target = noise + # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier) + # set masked multiplier to 1.0 so we dont double apply it + # mask_multiplier = 1.0 + elif prior_pred is not None and not self.train_config.do_prior_divergence: + assert not self.train_config.train_turbo + # matching adapter prediction + target = prior_pred + elif self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) + + elif self.sd.is_flow_matching: + target = (noise - batch.latents).detach() + else: + target = noise + + if target is None: + target = noise + + pred = noise_pred + 
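# --- Illustrative sketch (not part of the patch above) ---
# Compact reference for the default loss target selected above, assuming the standard
# parameterizations x_t = alpha_t * x_0 + sigma_t * eps for DDPM-style schedulers and a
# rectified-flow target of eps - x_0. The branches above (and the scheduler's get_velocity)
# are authoritative; this just restates them in one place.
def default_loss_target(x0, eps, alpha_t, sigma_t, prediction_type: str = "epsilon"):
    if prediction_type == "epsilon":
        return eps                           # predict the injected noise
    if prediction_type == "v_prediction":
        return alpha_t * eps - sigma_t * x0  # velocity target, i.e. what get_velocity computes
    if prediction_type == "flow_matching":
        return eps - x0                      # matches `noise - batch.latents` above
    raise ValueError(f"unknown prediction type: {prediction_type}")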
+ if self.train_config.train_turbo: + pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch) + + ignore_snr = False + + if loss_target == 'source' or loss_target == 'unaugmented': + assert not self.train_config.train_turbo + # ignore_snr = True + if batch.sigmas is None: + raise ValueError("Batch sigmas is None. This should not happen") + + # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190 + denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents + weighing = batch.sigmas ** -2.0 + if loss_target == 'source': + # denoise the latent and compare to the latent in the batch + target = batch.latents + elif loss_target == 'unaugmented': + # we have to encode images into latents for now + # we also denoise as the unaugmented tensor is not a noisy diffirental + with torch.no_grad(): + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype) + unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier + target = unaugmented_latents.detach() + + # Get the target for loss depending on the prediction type + if self.sd.noise_scheduler.config.prediction_type == "epsilon": + target = target # we are computing loss against denoise latents + elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + + # mse loss without reduction + loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) + loss = loss_per_element + else: + + if self.train_config.loss_type == "mae": + loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") + else: + loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + + # handle linear timesteps and only adjust the weight of the timesteps + if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2): + # calculate the weights for the timesteps + timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps( + timesteps, + v2=self.train_config.linear_timesteps2 + ).to(loss.device, dtype=loss.dtype) + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + loss = loss * timestep_weight + + if self.train_config.do_prior_divergence and prior_pred is not None: + loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) + + if self.train_config.train_turbo: + mask_multiplier = mask_multiplier[:, 3:, :, :] + # resize to the size of the loss + mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest') + + # multiply by our mask + loss = loss * mask_multiplier + + prior_loss = None + if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: + assert not self.train_config.train_turbo + if self.train_config.loss_type == "mae": + prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none") + else: + prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") + + prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier + if torch.isnan(prior_loss).any(): + print("Prior loss is nan") + prior_loss = 
None + else: + prior_loss = prior_loss.mean([1, 2, 3]) + # loss = loss + prior_loss + # loss = loss + prior_loss + # loss = loss + prior_loss + loss = loss.mean([1, 2, 3]) + # apply loss multiplier before prior loss + loss = loss * loss_multiplier + if prior_loss is not None: + loss = loss + prior_loss + + if not self.train_config.train_turbo: + if self.train_config.learnable_snr_gos: + # add snr_gamma + loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos) + elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr: + # add snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, + fixed=True) + elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # check for additional losses + if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None: + + loss = loss + self.adapter.additional_loss.mean() + self.adapter.additional_loss = None + + if self.train_config.target_norm_std: + # seperate out the batch and channels + pred_std = noise_pred.std([2, 3], keepdim=True) + norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean() + loss = loss + norm_std_loss + + + return loss + + def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): + return batch + + def get_guided_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + loss = get_guidance_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + sd=self.sd, + unconditional_embeds=unconditional_embeds, + scaler=self.scaler, + **kwargs + ) + + return loss + + def get_guided_loss_targeted_polarity( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + **kwargs + ): + with torch.no_grad(): + # Perform targeted guidance (working title) + dtype = get_torch_dtype(self.train_config.dtype) + + conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach() + + mean_latents = (conditional_latents + unconditional_latents) / 2.0 + + unconditional_diff = (unconditional_latents - mean_latents) + conditional_diff = (conditional_latents - mean_latents) + + # we need to determine the amount of signal and noise that would be present at the current timestep + # conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps) + # unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps) + # unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps) + # conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, 
timesteps) + # unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps) + + # target_noise = noise + unconditional_signal + + conditional_noisy_latents = self.sd.add_noise( + mean_latents, + noise, + timesteps + ).detach() + + unconditional_noisy_latents = self.sd.add_noise( + mean_latents, + noise, + timesteps + ).detach() + + # Disable the LoRA network so we can predict parent network knowledge without it + self.network.is_active = False + self.sd.unet.eval() + + # Predict noise to get a baseline of what the parent network wants to do with the latents + noise. + # This acts as our control to preserve the unaltered parts of the image. + baseline_prediction = self.sd.predict_noise( + latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ).detach() + + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + + # since we are dividing the polarity from the middle out, we need to double our network + # weights on training since the convergent point will be at half network strength + + negative_network_weights = [weight * -2.0 for weight in network_weight_list] + positive_network_weights = [weight * 2.0 for weight in network_weight_list] + cat_network_weight_list = positive_network_weights + negative_network_weights + + # turn the LoRA network back on. + self.sd.unet.train() + self.network.is_active = True + + self.network.multiplier = cat_network_weight_list + + # do our prediction with LoRA active on the scaled guidance latents + prediction = self.sd.predict_noise( + latents=cat_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) + + pred_pos = pred_pos - baseline_prediction + pred_neg = pred_neg - baseline_prediction + + pred_loss = torch.nn.functional.mse_loss( + pred_pos.float(), + unconditional_diff.float(), + reduction="none" + ) + pred_loss = pred_loss.mean([1, 2, 3]) + + pred_neg_loss = torch.nn.functional.mse_loss( + pred_neg.float(), + conditional_diff.float(), + reduction="none" + ) + pred_neg_loss = pred_neg_loss.mean([1, 2, 3]) + + loss = (pred_loss + pred_neg_loss) / 2.0 + + # loss = self.apply_snr(loss, timesteps) + loss = loss.mean() + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + + def get_guided_loss_masked_polarity( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + **kwargs + ): + with torch.no_grad(): + # Perform targeted guidance (working title) + dtype = get_torch_dtype(self.train_config.dtype) + + conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach() + 
inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents) + + mean_latents = (conditional_latents + unconditional_latents) / 2.0 + + # unconditional_diff = (unconditional_latents - mean_latents) + # conditional_diff = (conditional_latents - mean_latents) + + # we need to determine the amount of signal and noise that would be present at the current timestep + # conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps) + # unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps) + # unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps) + # conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps) + # unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps) + + # make a differential mask + differential_mask = torch.abs(conditional_latents - unconditional_latents) + max_differential = \ + differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + differential_scaler = 1.0 / max_differential + differential_mask = differential_mask * differential_scaler + spread_point = 0.1 + # adjust mask to amplify the differential at 0.1 + differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point + # clip it + differential_mask = torch.clamp(differential_mask, 0.0, 1.0) + + # target_noise = noise + unconditional_signal + + conditional_noisy_latents = self.sd.add_noise( + conditional_latents, + noise, + timesteps + ).detach() + + unconditional_noisy_latents = self.sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() + + inverse_noisy_latents = self.sd.add_noise( + inverse_latents, + noise, + timesteps + ).detach() + + # Disable the LoRA network so we can predict parent network knowledge without it + self.network.is_active = False + self.sd.unet.eval() + + # Predict noise to get a baseline of what the parent network wants to do with the latents + noise. + # This acts as our control to preserve the unaltered parts of the image. + # baseline_prediction = self.sd.predict_noise( + # latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), + # conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), + # timestep=timesteps, + # guidance_scale=1.0, + # **pred_kwargs # adapter residuals in here + # ).detach() + + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + + # since we are dividing the polarity from the middle out, we need to double our network + # weights on training since the convergent point will be at half network strength + + negative_network_weights = [weight * -1.0 for weight in network_weight_list] + positive_network_weights = [weight * 1.0 for weight in network_weight_list] + cat_network_weight_list = positive_network_weights + negative_network_weights + + # turn the LoRA network back on. 
+ self.sd.unet.train() + self.network.is_active = True + + self.network.multiplier = cat_network_weight_list + + # do our prediction with LoRA active on the scaled guidance latents + prediction = self.sd.predict_noise( + latents=cat_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) + + # create a loss to balance the mean to 0 between the two predictions + differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0 + + # pred_pos = pred_pos - baseline_prediction + # pred_neg = pred_neg - baseline_prediction + + pred_loss = torch.nn.functional.mse_loss( + pred_pos.float(), + noise.float(), + reduction="none" + ) + # apply mask + pred_loss = pred_loss * (1.0 + differential_mask) + pred_loss = pred_loss.mean([1, 2, 3]) + + pred_neg_loss = torch.nn.functional.mse_loss( + pred_neg.float(), + noise.float(), + reduction="none" + ) + # apply inverse mask + pred_neg_loss = pred_neg_loss * (1.0 - differential_mask) + pred_neg_loss = pred_neg_loss.mean([1, 2, 3]) + + # make a loss to balance to losses of the pos and neg so they are equal + # differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss) + # + # differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss + # + # # add a multiplier to balancing losses to make them the top priority + # differential_mean_loss = differential_mean_loss + + # remove the grads from the negative as it is only a balancing loss + # pred_neg_loss = pred_neg_loss.detach() + + # loss = pred_loss + pred_neg_loss + differential_mean_loss + loss = pred_loss + pred_neg_loss + + # loss = self.apply_snr(loss, timesteps) + loss = loss.mean() + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + + def get_prior_prediction( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + conditioned_prompts=None, + **kwargs + ): + # todo for embeddings, we need to run without trigger words + was_unet_training = self.sd.unet.training + was_network_active = False + if self.network is not None: + was_network_active = self.network.is_active + self.network.is_active = False + can_disable_adapter = False + was_adapter_active = False + if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or + isinstance(self.adapter, ReferenceAdapter) or + (isinstance(self.adapter, CustomAdapter)) + ): + can_disable_adapter = True + was_adapter_active = self.adapter.is_active + self.adapter.is_active = False + + if self.train_config.unload_text_encoder: + raise ValueError("Prior predictions currently do not support unloading text encoder") + # do a prediction here so we can match its output with network multiplier set to 0.0 + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + + embeds_to_use = conditional_embeds.clone().detach() + # handle clip vision adapter by removing triggers from prompt and replacing with the class name + if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None: 
+ prompt_list = batch.get_caption_list() + class_name = '' + + triggers = ['[trigger]', '[name]'] + remove_tokens = [] + + if self.embed_config is not None: + triggers.append(self.embed_config.trigger) + for i in range(1, self.embed_config.tokens): + remove_tokens.append(f"{self.embed_config.trigger}_{i}") + if self.embed_config.trigger_class_name is not None: + class_name = self.embed_config.trigger_class_name + + if self.adapter is not None: + triggers.append(self.adapter_config.trigger) + for i in range(1, self.adapter_config.num_tokens): + remove_tokens.append(f"{self.adapter_config.trigger}_{i}") + if self.adapter_config.trigger_class_name is not None: + class_name = self.adapter_config.trigger_class_name + + for idx, prompt in enumerate(prompt_list): + for remove_token in remove_tokens: + prompt = prompt.replace(remove_token, '') + for trigger in triggers: + prompt = prompt.replace(trigger, class_name) + prompt_list[idx] = prompt + + embeds_to_use = self.sd.encode_prompt( + prompt_list, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype).detach() + + # dont use network on this + # self.network.multiplier = 0.0 + self.sd.unet.eval() + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux: + # we need to remove the image embeds from the prompt except for flux + embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach() + end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens + embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :] + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.clone().detach() + unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos] + + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + + prior_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(), + unconditional_embeddings=unconditional_embeds, + timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + rescale_cfg=self.train_config.cfg_rescale, + **pred_kwargs # adapter residuals in here + ) + if was_unet_training: + self.sd.unet.train() + prior_pred = prior_pred.detach() + # remove the residuals as we wont use them on prediction when matching control + if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs: + del pred_kwargs['down_intrablock_additional_residuals'] + if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: + del pred_kwargs['down_block_additional_residuals'] + if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs: + del pred_kwargs['mid_block_additional_residual'] + + if can_disable_adapter: + self.adapter.is_active = was_adapter_active + # restore network + # self.network.multiplier = network_weight_list + if self.network is not None: + self.network.is_active = was_network_active + return prior_pred + + def before_unet_predict(self): + pass + + def after_unet_predict(self): + pass + + def end_of_training_loop(self): + pass + + def predict_noise( + self, + noisy_latents: torch.Tensor, + timesteps: Union[int, torch.Tensor] = 1, + conditional_embeds: Union[PromptEmbeds, None] = None, + unconditional_embeds: Union[PromptEmbeds, None] = None, + **kwargs, + ): + dtype = get_torch_dtype(self.train_config.dtype) + return self.sd.predict_noise( + 
latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeddings=unconditional_embeds, + timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + guidance_embedding_scale=self.train_config.cfg_scale, + detach_unconditional=False, + rescale_cfg=self.train_config.cfg_rescale, + bypass_guidance_embedding=self.train_config.bypass_guidance_embedding, + **kwargs + ) + + def train_single_accumulation(self, batch: DataLoaderBatchDTO): + self.timer.start('preprocess_batch') + batch = self.preprocess_batch(batch) + dtype = get_torch_dtype(self.train_config.dtype) + # sanity check + if self.sd.vae.dtype != self.sd.vae_torch_dtype: + self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + if encoder.dtype != self.sd.te_torch_dtype: + encoder.to(self.sd.te_torch_dtype) + else: + if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: + self.sd.text_encoder.to(self.sd.te_torch_dtype) + + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.train_config.do_cfg or self.train_config.do_random_cfg: + # pick random negative prompts + if self.negative_prompt_pool is not None: + negative_prompts = [] + for i in range(noisy_latents.shape[0]): + num_neg = random.randint(1, self.train_config.max_negative_prompts) + this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] + this_neg_prompt = ', '.join(this_neg_prompts) + negative_prompts.append(this_neg_prompt) + self.batch_negative_prompt = negative_prompts + else: + self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] + + if self.adapter and isinstance(self.adapter, CustomAdapter): + # condition the prompt + # todo handle more than one adapter image + self.adapter.num_control_images = 1 + conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) + + network_weight_list = batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list + + has_adapter_img = batch.control_tensor is not None + has_clip_image = batch.clip_image_tensor is not None + has_clip_image_embeds = batch.clip_image_embeds is not None + # force it to be true if doing regs as we handle those differently + if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): + has_clip_image = True + if self._clip_image_embeds_unconditional is not None: + has_clip_image_embeds = True # we are caching embeds, handle that differently + has_clip_image = False + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: + raise ValueError( + "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. 
Please update your dataset config ") + + match_adapter_assist = False + + # check if we are matching the adapter assistant + if self.assistant_adapter: + if self.train_config.match_adapter_chance == 1.0: + match_adapter_assist = True + elif self.train_config.match_adapter_chance > 0.0: + match_adapter_assist = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) < self.train_config.match_adapter_chance + + self.timer.stop('preprocess_batch') + + is_reg = False + with torch.no_grad(): + loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + for idx, file_item in enumerate(batch.file_items): + if file_item.is_reg: + loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight + is_reg = True + + adapter_images = None + sigmas = None + if has_adapter_img and (self.adapter or self.assistant_adapter): + with self.timer('get_adapter_images'): + # todo move this to data loader + if batch.control_tensor is not None: + adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() + # match in channels + if self.assistant_adapter is not None: + in_channels = self.assistant_adapter.config.in_channels + if adapter_images.shape[1] != in_channels: + # we need to match the channels + adapter_images = adapter_images[:, :in_channels, :, :] + else: + raise NotImplementedError("Adapter images now must be loaded with dataloader") + + clip_images = None + if has_clip_image: + with self.timer('get_clip_images'): + # todo move this to data loader + if batch.clip_image_tensor is not None: + clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach() + + mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + if batch.mask_tensor is not None: + with self.timer('get_mask_multiplier'): + # upsampling no supported for bfloat16 + mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach() + # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height) + mask_multiplier = torch.nn.functional.interpolate( + mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3]) + ) + # expand to match latents + mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) + mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + + def get_adapter_multiplier(): + if self.adapter and isinstance(self.adapter, T2IAdapter): + # training a t2i adapter, not using as assistant. + return 1.0 + elif match_adapter_assist: + # training a texture. 
We want it high + adapter_strength_min = 0.9 + adapter_strength_max = 1.0 + else: + # training with assistance, we want it low + # adapter_strength_min = 0.4 + # adapter_strength_max = 0.7 + adapter_strength_min = 0.5 + adapter_strength_max = 1.1 + + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) + + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + adapter_strength_max + ) + return adapter_conditioning_scale + + # flush() + with self.timer('grad_setup'): + + # text encoding + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding is not None: + grad_on_text_encoder = True + + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + grad_on_text_encoder = True + + if self.adapter_config and self.adapter_config.type == 'te_augmenter': + grad_on_text_encoder = True + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: + network = BlankNetwork() + + # set the weights + network.multiplier = network_weight_list + + # activate network if it exits + + prompts_1 = conditioned_prompts + prompts_2 = None + if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl: + prompts_1 = batch.get_caption_short_list() + prompts_2 = conditioned_prompts + + # make the batch splits + if self.train_config.single_item_batching: + if self.model_config.refiner_name_or_path is not None: + raise ValueError("Single item batching is not supported when training the refiner") + batch_size = noisy_latents.shape[0] + # chunk/split everything + noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0) + noise_list = torch.chunk(noise, batch_size, dim=0) + timesteps_list = torch.chunk(timesteps, batch_size, dim=0) + conditioned_prompts_list = [[prompt] for prompt in prompts_1] + if imgs is not None: + imgs_list = torch.chunk(imgs, batch_size, dim=0) + else: + imgs_list = [None for _ in range(batch_size)] + if adapter_images is not None: + adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) + else: + adapter_images_list = [None for _ in range(batch_size)] + if clip_images is not None: + clip_images_list = torch.chunk(clip_images, batch_size, dim=0) + else: + clip_images_list = [None for _ in range(batch_size)] + mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) + if prompts_2 is None: + prompt_2_list = [None for _ in range(batch_size)] + else: + prompt_2_list = [[prompt] for prompt in prompts_2] + + else: + noisy_latents_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditioned_prompts_list = [prompts_1] + imgs_list = [imgs] + adapter_images_list = [adapter_images] + clip_images_list = [clip_images] + mask_multiplier_list = [mask_multiplier] + if prompts_2 is None: + prompt_2_list = [None] + else: + prompt_2_list = [prompts_2] + + for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip( + noisy_latents_list, + noise_list, + timesteps_list, + conditioned_prompts_list, + imgs_list, + adapter_images_list, + clip_images_list, + mask_multiplier_list, + prompt_2_list + ): + + # if self.train_config.negative_prompt is not None: + # # add negative prompt + # conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in + # 
range(len(conditioned_prompts))] + # if prompt_2 is not None: + # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] + + with (network): + # encode clip adapter here so embeds are active for tokenizer + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('encode_clip_vision_embeds'): + if has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True + ) + else: + # just do a blank one + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, 512, 512), + device=self.device_torch, dtype=dtype + ), + is_training=True, + has_been_preprocessed=True, + drop=True + ) + # it will be injected into the tokenizer when called + self.adapter(conditional_clip_embeds) + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg): + quad_count = random.randint(1, 4) + self.adapter.train() + self.adapter.trigger_pre_te( + tensors_0_1=clip_images if not is_reg else None, # on regs we send none to get random noise + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + batch_size=noisy_latents.shape[0] + ) + + with self.timer('encode_prompt'): + unconditional_embeds = None + if self.train_config.unload_text_encoder: + with torch.set_grad_enabled(False): + embeds_to_use = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + if self.cached_trigger_embeds is not None and not is_reg: + embeds_to_use = self.cached_trigger_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + conditional_embeds = concat_prompt_embeds( + [embeds_to_use] * noisy_latents.shape[0] + ) + if self.train_config.do_cfg: + unconditional_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + unconditional_embeds = concat_prompt_embeds( + [unconditional_embeds] * noisy_latents.shape[0] + ) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + elif grad_on_text_encoder: + with torch.set_grad_enabled(True): + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype) + + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + # todo only do one and repeat it + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + else: + with torch.set_grad_enabled(False): + # make sure it is in eval mode + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + else: + self.sd.text_encoder.eval() + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts).to( + 
self.device_torch, + dtype=dtype) + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + if self.train_config.do_cfg: + unconditional_embeds = unconditional_embeds.detach() + + if self.decorator: + conditional_embeds.text_embeds = self.decorator( + conditional_embeds.text_embeds + ) + if self.train_config.do_cfg: + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, + is_unconditional=True + ) + + # flush() + pred_kwargs = {} + + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, T2IAdapter)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # not training. detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals + ] + + pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals + + if self.adapter and isinstance(self.adapter, IPAdapter): + with self.timer('encode_adapter_embeds'): + # number of images to do if doing a quad image + quad_count = random.randint(1, 4) + image_size = self.adapter.input_size + if has_clip_image_embeds: + # todo handle reg images better than this + if is_reg: + # get unconditional image embeds from cache + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + if self.train_config.do_cfg: + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + else: + conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds_unconditional, + quad_count=quad_count + ) + elif is_reg: + # we will zero it out in the img embedder + clip_images = torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach() + # drop will zero it out + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images, + drop=True, + is_training=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + 
(noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach(), + is_training=True, + drop=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + elif has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + # do cfg on clip embeds to normalize the embeddings for when doing cfg + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + drop=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + else: + print("No Clip Image") + print([file_item.path for file_item in batch.file_items]) + raise ValueError("Could not find clip image") + + if not self.adapter_config.train_image_encoder: + # we are not training the image encoder, so we need to detach the embeds + conditional_clip_embeds = conditional_clip_embeds.detach() + if self.train_config.do_cfg: + unconditional_clip_embeds = unconditional_clip_embeds.detach() + + with self.timer('encode_adapter'): + self.adapter.train() + conditional_embeds = self.adapter( + conditional_embeds.detach(), + conditional_clip_embeds, + is_unconditional=False + ) + if self.train_config.do_cfg: + unconditional_embeds = self.adapter( + unconditional_embeds.detach(), + unconditional_clip_embeds, + is_unconditional=True + ) + else: + # wipe out unconsitional + self.adapter.last_unconditional = None + + if self.adapter and isinstance(self.adapter, ReferenceAdapter): + # pass in our scheduler + self.adapter.noise_scheduler = self.lr_scheduler + if has_clip_image or has_adapter_img: + img_to_use = clip_images if has_clip_image else adapter_images + # currently 0-1 needs to be -1 to 1 + reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype) + self.adapter.set_reference_images(reference_images) + self.adapter.noise_scheduler = self.sd.noise_scheduler + elif is_reg: + self.adapter.set_blank_reference_images(noisy_latents.shape[0]) + else: + self.adapter.set_reference_images(None) + + prior_pred = None + + do_reg_prior = False + # if is_reg and (self.network is not None or self.adapter is not None): + # # we are doing a reg image and we have a network or adapter + # do_reg_prior = True + + do_inverted_masked_prior = False + if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: + do_inverted_masked_prior = True + + do_correct_pred_norm_prior = self.train_config.correct_pred_norm + + do_guidance_prior = False + + if batch.unconditional_latents is not None: + # for this not that, we need a prior pred to normalize + guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type + if guidance_type == 'tnt': + do_guidance_prior = True + + if (( + has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): + with self.timer('prior predict'): + prior_pred = self.get_prior_prediction( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + 
pred_kwargs=pred_kwargs, + noise=noise, + batch=batch, + unconditional_embeds=unconditional_embeds, + conditioned_prompts=conditioned_prompts + ) + if prior_pred is not None: + prior_pred = prior_pred.detach() + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image: + quad_count = random.randint(1, 4) + self.adapter.train() + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=conditional_embeds, + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + if self.train_config.do_cfg and unconditional_embeds is not None: + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=unconditional_embeds, + is_training=True, + has_been_preprocessed=True, + is_unconditional=True, + quad_count=quad_count + ) + + if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None: + self.adapter.add_extra_values(batch.extra_values.detach()) + + if self.train_config.do_cfg: + self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), + is_unconditional=True) + + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, ControlNetModel)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): + if self.train_config.do_cfg: + raise ValueError("ControlNetModel is not supported with CFG") + with torch.set_grad_enabled(self.adapter is not None): + adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + # add_text_embeds is pooled_prompt_embeds for sdxl + added_cond_kwargs = {} + if self.sd.is_xl: + added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds + added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents) + down_block_res_samples, mid_block_res_sample = adapter( + noisy_latents, + timesteps, + encoder_hidden_states=conditional_embeds.text_embeds, + controlnet_cond=adapter_images, + conditioning_scale=1.0, + guess_mode=False, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + pred_kwargs['down_block_additional_residuals'] = down_block_res_samples + pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample + + self.before_unet_predict() + # do a prior pred if we have an unconditional image, we will swap out the giadance later + if batch.unconditional_latents is not None or self.do_guided_loss: + # do guided loss + loss = self.get_guided_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + + else: + with self.timer('predict_unet'): + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + noise_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + **pred_kwargs + ) + self.after_unet_predict() + + with self.timer('calculate_loss'): + noise = noise.to(self.device_torch, 
dtype=dtype).detach() + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + # check if nan + if torch.isnan(loss): + print("loss is nan") + loss = torch.zeros_like(loss).requires_grad_(True) + + with self.timer('backward'): + # todo we have multiplier seperated. works for now as res are not in same batch, but need to change + loss = loss * loss_multiplier.mean() + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. DON'T DO IT + # with fsdp_overlap_step_with_backward(): + # if self.is_bfloat: + # loss.backward() + # else: + if not self.do_grad_scale: + loss.backward() + else: + self.scaler.scale(loss).backward() + + return loss.detach() + # flush() + + def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]): + if isinstance(batch, list): + batch_list = batch + else: + batch_list = [batch] + total_loss = None + self.optimizer.zero_grad() + for batch in batch_list: + loss = self.train_single_accumulation(batch) + if total_loss is None: + total_loss = loss + else: + total_loss += loss + if len(batch_list) > 1 and self.model_config.low_vram: + torch.cuda.empty_cache() + + + if not self.is_grad_accumulation_step: + # fix this for multi params + if self.train_config.optimizer != 'adafactor': + if self.do_grad_scale: + self.scaler.unscale_(self.optimizer) + if isinstance(self.params[0], dict): + for i in range(len(self.params)): + torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + # only step if we are not accumulating + with self.timer('optimizer_step'): + # self.optimizer.step() + if not self.do_grad_scale: + self.optimizer.step() + else: + self.scaler.step(self.optimizer) + self.scaler.update() + + self.optimizer.zero_grad(set_to_none=True) + if self.adapter and isinstance(self.adapter, CustomAdapter): + self.adapter.post_weight_update() + if self.ema is not None: + with self.timer('ema_update'): + self.ema.update() + else: + # gradient accumulation. Just a place for breakpoint + pass + + # TODO Should we only step scheduler on grad step? 
If so, need to recalculate last step + with self.timer('scheduler_step'): + self.lr_scheduler.step() + + if self.embedding is not None: + with self.timer('restore_embeddings'): + # Let's make sure we don't update any embedding weights besides the newly added token + self.embedding.restore_embeddings() + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('restore_adapter'): + # Let's make sure we don't update any embedding weights besides the newly added token + self.adapter.restore_embeddings() + + loss_dict = OrderedDict( + {'loss': loss.item()} + ) + + self.end_of_training_loop() + + return loss_dict diff --git a/extensions_built_in/sd_trainer/__init__.py b/extensions_built_in/sd_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45aa841ebcb0287a5309017fb9ab17d425c581a1 --- /dev/null +++ b/extensions_built_in/sd_trainer/__init__.py @@ -0,0 +1,30 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class SDTrainerExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "sd_trainer" + + # name is the name of the extension for printing + name = "SD Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .SDTrainer import SDTrainer + return SDTrainer + + +# for backwards compatability +class TextualInversionTrainer(SDTrainerExtension): + uid = "textual_inversion_trainer" + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + SDTrainerExtension, TextualInversionTrainer +] diff --git a/extensions_built_in/sd_trainer/config/train.example.yaml b/extensions_built_in/sd_trainer/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..793d5d55b9a282d58b53f0e6fafba0b5aa66f7af --- /dev/null +++ b/extensions_built_in/sd_trainer/config/train.example.yaml @@ -0,0 +1,91 @@ +--- +job: extension +config: + name: test_v1 + process: + - type: 'textual_inversion_trainer' + training_folder: "out/TI" + device: cuda:0 + # for tensorboard logging + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_ext: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + train_unet: false + gradient_checkpointing: true + train_text_encoder: false + optimizer: "adamw" +# optimizer: "prodigy" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 4 + dtype: bf16 + xformers: true + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this + model: + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 
models) + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of [trigger] laughing" + - "photo of [trigger] smiling" + - "[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py b/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..857cfa7578f29ff1343f0650209f81b6720d4ce0 --- /dev/null +++ b/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py @@ -0,0 +1,533 @@ +import copy +import random +from collections import OrderedDict +import os +from contextlib import nullcontext +from typing import Optional, Union, List +from torch.utils.data import ConcatDataset, DataLoader + +from toolkit.config_modules import ReferenceDatasetConfig +from toolkit.data_loader import PairedImageDataset +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds, build_latent_image_batch_for_prompt_pair +from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +from toolkit import train_tools +import torch +from jobs.process import BaseSDTrainProcess +import random + +import random +from collections import OrderedDict +from tqdm import tqdm + +from toolkit.config_modules import SliderConfig +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +from toolkit import train_tools +from toolkit.prompt_utils import \ + EncodedPromptPair, ACTION_TYPES_SLIDER, \ + EncodedAnchor, concat_prompt_pairs, \ + concat_anchors, PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \ + split_prompt_pairs + +import torch + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class UltimateSliderConfig(SliderConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.additional_losses: List[str] = kwargs.get('additional_losses', []) + self.weight_jitter: float = kwargs.get('weight_jitter', 0.0) + self.img_loss_weight: float = kwargs.get('img_loss_weight', 1.0) + self.cfg_loss_weight: float = kwargs.get('cfg_loss_weight', 1.0) + self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])] + + +class UltimateSliderTrainerProcess(BaseSDTrainProcess): + sd: StableDiffusion + data_loader: DataLoader = None + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + 
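+ # runtime state for this process: prompt list/cache, step counters, device handles, and the parsed UltimateSliderConfig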
self.prompt_txt_list = None + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.slider_config = UltimateSliderConfig(**self.get_conf('slider', {})) + + self.prompt_cache = PromptEmbedsCache() + self.prompt_pairs: list[EncodedPromptPair] = [] + self.anchor_pairs: list[EncodedAnchor] = [] + # keep track of prompt chunk size + self.prompt_chunk_size = 1 + + # store a list of all the prompts from the dataset so we can cache it + self.dataset_prompts = [] + self.train_with_dataset = self.slider_config.datasets is not None and len(self.slider_config.datasets) > 0 + + def load_datasets(self): + if self.data_loader is None and \ + self.slider_config.datasets is not None and len(self.slider_config.datasets) > 0: + print(f"Loading datasets") + datasets = [] + for dataset in self.slider_config.datasets: + print(f" - Dataset: {dataset.pair_folder}") + config = { + 'path': dataset.pair_folder, + 'size': dataset.size, + 'default_prompt': dataset.target_class, + 'network_weight': dataset.network_weight, + 'pos_weight': dataset.pos_weight, + 'neg_weight': dataset.neg_weight, + 'pos_folder': dataset.pos_folder, + 'neg_folder': dataset.neg_folder, + } + image_dataset = PairedImageDataset(config) + datasets.append(image_dataset) + + # capture all the prompts from it so we can cache the embeds + self.dataset_prompts += image_dataset.get_all_prompts() + + concatenated_dataset = ConcatDataset(datasets) + self.data_loader = DataLoader( + concatenated_dataset, + batch_size=self.train_config.batch_size, + shuffle=True, + num_workers=2 + ) + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + # load any datasets if they were passed + self.load_datasets() + + # read line by line from file + if self.slider_config.prompt_file: + self.print(f"Loading prompt file from {self.slider_config.prompt_file}") + with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f: + self.prompt_txt_list = f.readlines() + # clean empty lines + self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] + + self.print(f"Found {len(self.prompt_txt_list)} prompts.") + + if not self.slider_config.prompt_tensors: + print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.") + # shuffle + random.shuffle(self.prompt_txt_list) + # trim to max steps + self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps] + # trim list to our max steps + + cache = PromptEmbedsCache() + + # get encoded latents for our prompts + with torch.no_grad(): + # list of neutrals. 
Can come from file or be empty + neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""] + + # build the prompts to cache + prompts_to_cache = [] + for neutral in neutral_list: + for target in self.slider_config.targets: + prompt_list = [ + f"{target.target_class}", # target_class + f"{target.target_class} {neutral}", # target_class with neutral + f"{target.positive}", # positive_target + f"{target.positive} {neutral}", # positive_target with neutral + f"{target.negative}", # negative_target + f"{target.negative} {neutral}", # negative_target with neutral + f"{neutral}", # neutral + f"{target.positive} {target.negative}", # both targets + f"{target.negative} {target.positive}", # both targets reverse + ] + prompts_to_cache += prompt_list + + # remove duplicates + prompts_to_cache = list(dict.fromkeys(prompts_to_cache)) + + # trim to max steps if max steps is lower than prompt count + prompts_to_cache = prompts_to_cache[:self.train_config.steps] + + if len(self.dataset_prompts) > 0: + # add the prompts from the dataset + prompts_to_cache += self.dataset_prompts + + # encode them + cache = encode_prompts_to_cache( + prompt_list=prompts_to_cache, + sd=self.sd, + cache=cache, + prompt_tensor_file=self.slider_config.prompt_tensors + ) + + prompt_pairs = [] + prompt_batches = [] + for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False): + for target in self.slider_config.targets: + prompt_pair_batch = build_prompt_pair_batch_from_cache( + cache=cache, + target=target, + neutral=neutral, + + ) + if self.slider_config.batch_full_slide: + # concat the prompt pairs + # this allows us to run the entire 4 part process in one shot (for slider) + self.prompt_chunk_size = 4 + concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu') + prompt_pairs += [concat_prompt_pair_batch] + else: + self.prompt_chunk_size = 1 + # do them one at a time (probably not necessary after new optimizations) + prompt_pairs += [x.to('cpu') for x in prompt_pair_batch] + + # move to cpu to save vram + # We don't need text encoder anymore, but keep it on cpu for sampling + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") + self.prompt_cache = cache + self.prompt_pairs = prompt_pairs + # end hook_before_train_loop + + # move vae to device so we can encode on the fly + # todo cache latents + self.sd.vae.to(self.device_torch) + self.sd.vae.eval() + self.sd.vae.requires_grad_(False) + + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + + flush() + # end hook_before_train_loop + + def hook_train_loop(self, batch): + dtype = get_torch_dtype(self.train_config.dtype) + + with torch.no_grad(): + ### LOOP SETUP ### + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + ### TARGET_PROMPTS ### + # get a random pair + prompt_pair: EncodedPromptPair = self.prompt_pairs[ + torch.randint(0, len(self.prompt_pairs), (1,)).item() + ] + # move to device and dtype + prompt_pair.to(self.device_torch, dtype=dtype) + + ### PREP REFERENCE IMAGES ### + + imgs, prompts, network_weights = batch + network_pos_weight, network_neg_weight = network_weights + + if isinstance(network_pos_weight, torch.Tensor): + network_pos_weight = network_pos_weight.item() + if isinstance(network_neg_weight, torch.Tensor): + network_neg_weight = 
network_neg_weight.item() + + # get an array of random floats between -weight_jitter and weight_jitter + weight_jitter = self.slider_config.weight_jitter + if weight_jitter > 0.0: + jitter_list = random.uniform(-weight_jitter, weight_jitter) + network_pos_weight += jitter_list + network_neg_weight += (jitter_list * -1.0) + + # if items in network_weight list are tensors, convert them to floats + imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype) + # split batched images in half so left is negative and right is positive + negative_images, positive_images = torch.chunk(imgs, 2, dim=3) + + height = positive_images.shape[2] + width = positive_images.shape[3] + batch_size = positive_images.shape[0] + + positive_latents = self.sd.encode_images(positive_images) + negative_latents = self.sd.encode_images(negative_images) + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + current_timestep_index = timesteps.item() + current_timestep = noise_scheduler.timesteps[current_timestep_index] + timesteps = timesteps.long() + + # get noise + noise_positive = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + noise_negative = noise_positive.clone() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps) + noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps) + + ### CFG SLIDER TRAINING PREP ### + + # get CFG txt latents + noisy_cfg_latents = build_latent_image_batch_for_prompt_pair( + pos_latent=noisy_positive_latents, + neg_latent=noisy_negative_latents, + prompt_pair=prompt_pair, + prompt_chunk_size=self.prompt_chunk_size, + ) + noisy_cfg_latents.requires_grad = False + + assert not self.network.is_active + + # 4.20 GB RAM for 512x512 + positive_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.negative_target, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + positive_latents.requires_grad = False + + neutral_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.empty_prompt, # positive prompt (normally neutral + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + neutral_latents.requires_grad = False + + unconditional_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.positive_target, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + unconditional_latents.requires_grad = False + + positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0) + neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0) + unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0) + 
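+ # split the prompt pair and CFG latents into matching chunks so they can be zipped with the per-chunk predictions above for the CFG slider loss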
prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size) + noisy_cfg_latents_chunks = torch.chunk(noisy_cfg_latents, self.prompt_chunk_size, dim=0) + assert len(prompt_pair_chunks) == len(noisy_cfg_latents_chunks) + + noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0) + noise = torch.cat([noise_positive, noise_negative], dim=0) + timesteps = torch.cat([timesteps, timesteps], dim=0) + network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0] + + flush() + + loss_float = None + loss_mirror_float = None + + self.optimizer.zero_grad() + noisy_latents.requires_grad = False + + # TODO allow both processed to train text encoder, for now, we just to unet and cache all text encodes + # if training text encoder enable grads, else do context of no grad + # with torch.set_grad_enabled(self.train_config.train_text_encoder): + # # text encoding + # embedding_list = [] + # # embed the prompts + # for prompt in prompts: + # embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + # embedding_list.append(embedding) + # conditional_embeds = concat_prompt_embeds(embedding_list) + # conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + + if self.train_with_dataset: + embedding_list = [] + with torch.set_grad_enabled(self.train_config.train_text_encoder): + for prompt in prompts: + # get embedding form cache + embedding = self.prompt_cache[prompt] + embedding = embedding.to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + conditional_embeds = concat_prompt_embeds(embedding_list) + # double up so we can do both sides of the slider + conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + else: + # throw error. 
Not supported yet + raise Exception("Datasets and targets required for ultimate slider") + + if self.model_config.is_xl: + # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop + network_multiplier_list = network_multiplier + noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0) + noise_list = torch.chunk(noise, 2, dim=0) + timesteps_list = torch.chunk(timesteps, 2, dim=0) + conditional_embeds_list = split_prompt_embeds(conditional_embeds) + else: + network_multiplier_list = [network_multiplier] + noisy_latent_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditional_embeds_list = [conditional_embeds] + + ## DO REFERENCE IMAGE TRAINING ## + + reference_image_losses = [] + # allow to chunk it out to save vram + for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip( + network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list + ): + with self.network: + assert self.network.is_active + + self.network.multiplier = network_multiplier + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + ) + noise = noise.to(self.device_torch, dtype=dtype) + + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + # todo add snr gamma here + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + loss = loss * self.slider_config.img_loss_weight + loss_slide_float = loss.item() + + loss_float = loss.item() + reference_image_losses.append(loss_float) + + # back propagate loss to free ram + loss.backward() + flush() + + ## DO CFG SLIDER TRAINING ## + + cfg_loss_list = [] + + with self.network: + assert self.network.is_active + for prompt_pair_chunk, \ + noisy_cfg_latent_chunk, \ + positive_latents_chunk, \ + neutral_latents_chunk, \ + unconditional_latents_chunk \ + in zip( + prompt_pair_chunks, + noisy_cfg_latents_chunks, + positive_latents_chunks, + neutral_latents_chunks, + unconditional_latents_chunks, + ): + self.network.multiplier = prompt_pair_chunk.multiplier_list + + target_latents = self.sd.predict_noise( + latents=noisy_cfg_latent_chunk, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair_chunk.positive_target, # negative prompt + prompt_pair_chunk.target_class, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + + guidance_scale = 1.0 + + offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk) + + # make offset multiplier based on actions + offset_multiplier_list = [] + for action in prompt_pair_chunk.action_list: + if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE: + offset_multiplier_list += [-1.0] + elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE: + offset_multiplier_list += [1.0] + + offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype) + # make offset multiplier match rank of offset + offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 
1) + offset *= offset_multiplier + + offset_neutral = neutral_latents_chunk + # offsets are already adjusted on a per-batch basis + offset_neutral += offset + + # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing + loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # match batch size + timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, + self.train_config.min_snr_gamma) + + loss = loss.mean() * prompt_pair_chunk.weight * self.slider_config.cfg_loss_weight + + loss.backward() + cfg_loss_list.append(loss.item()) + del target_latents + del offset_neutral + del loss + flush() + + # apply gradients + optimizer.step() + lr_scheduler.step() + + # reset network + self.network.multiplier = 1.0 + + reference_image_loss = sum(reference_image_losses) / len(reference_image_losses) if len( + reference_image_losses) > 0 else 0.0 + cfg_loss = sum(cfg_loss_list) / len(cfg_loss_list) if len(cfg_loss_list) > 0 else 0.0 + + loss_dict = OrderedDict({ + 'loss/img': reference_image_loss, + 'loss/cfg': cfg_loss, + }) + + return loss_dict + # end hook_train_loop diff --git a/extensions_built_in/ultimate_slider_trainer/__init__.py b/extensions_built_in/ultimate_slider_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7006db52a5882713dd071b683cb5892c2a0d00 --- /dev/null +++ b/extensions_built_in/ultimate_slider_trainer/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. 
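+# It follows the same pattern as sd_trainer/__init__.py: declare a unique uid, a printable name,
+# and a get_process() classmethod that imports the process class only when it is needed.
+# A config file selects this trainer by using the uid as the process type, e.g. (sketch only):
+#   config:
+#     process:
+#       - type: 'ultimate_slider_trainer'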
+from toolkit.extension import Extension + + +# We make a subclass of Extension +class UltimateSliderTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "ultimate_slider_trainer" + + # name is the name of the extension for printing + name = "Ultimate Slider Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .UltimateSliderTrainerProcess import UltimateSliderTrainerProcess + return UltimateSliderTrainerProcess + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + UltimateSliderTrainer +] diff --git a/extensions_built_in/ultimate_slider_trainer/config/train.example.yaml b/extensions_built_in/ultimate_slider_trainer/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b0f4734ae09fb7e942e33089014ffe59cfd7720 --- /dev/null +++ b/extensions_built_in/ultimate_slider_trainer/config/train.example.yaml @@ -0,0 +1,107 @@ +--- +job: extension +config: + name: example_name + process: + - type: 'image_reference_slider_trainer' + training_folder: "/mnt/Train/out/LoRA" + device: cuda:0 + # for tensorboard logging + log_dir: "/home/jaret/Dev/.tensorboard" + network: + type: "lora" + linear: 8 + linear_alpha: 8 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 5000 + lr: 1e-4 + train_unet: true + gradient_checkpointing: true + train_text_encoder: true + optimizer: "adamw" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 1 + dtype: bf16 + xformers: true + skip_first_sample: true + noise_offset: 0.0 + model: + name_or_path: "/path/to/model.safetensors" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + save: + dtype: float16 # precision to save + save_every: 1000 # save every this many steps + max_step_saves_to_keep: 2 # only affects step counts + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of a woman with red hair taking a selfie --m -3" + - "photo of a woman with red hair taking a selfie --m -1" + - "photo of a woman with red hair taking a selfie --m 1" + - "photo of a woman with red hair taking a selfie --m 3" + - "close up photo of a man smiling at the camera, in a tank top --m -3" + - "close up photo of a man smiling at the camera, in a tank top--m -1" + - "close up photo of a man smiling at the camera, in a tank top --m 1" + - "close up photo of a man smiling at the camera, in a tank top --m 3" + - "photo of a blonde woman smiling, barista --m -3" + - "photo of a blonde woman smiling, barista --m -1" + - "photo of a blonde woman smiling, barista --m 1" + - "photo of a blonde woman smiling, barista --m 3" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m 1" + - "photo of a Christina Hendricks --m 3" + - "photo of a Christina Ricci --m -3" + - "photo of a Christina Ricci --m -1" + - "photo of a Christina Ricci --m 1" + - "photo of a Christina Ricci --m 3" + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # 
log every this many steps + use_wandb: false # not supported yet + verbose: false + + slider: + datasets: + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 2.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 4.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. \ No newline at end of file diff --git a/flux_train_ui.py b/flux_train_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..54411d58c0beffbd185930d6da090f09bd660c4a --- /dev/null +++ b/flux_train_ui.py @@ -0,0 +1,414 @@ +import os +from huggingface_hub import whoami +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import sys + +# Add the current working directory to the Python path +sys.path.insert(0, os.getcwd()) + +import gradio as gr +from PIL import Image +import torch +import uuid +import os +import shutil +import json +import yaml +from slugify import slugify +from transformers import AutoProcessor, AutoModelForCausalLM + +sys.path.insert(0, "ai-toolkit") +from toolkit.job import get_job + +MAX_IMAGES = 150 + +def load_captioning(uploaded_files, concept_sentence): + uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')] + txt_files = [file for file in uploaded_files if file.endswith('.txt')] + txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files} + updates = [] + if len(uploaded_images) <= 1: + raise gr.Error( + "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)" + ) + elif len(uploaded_images) > MAX_IMAGES: + raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training") + # Update for the captioning_area + # for _ in range(3): + updates.append(gr.update(visible=True)) + # Update visibility and image for each captioning row and image + for i in range(1, MAX_IMAGES + 1): + # Determine if the current row and image should be visible + visible = i <= len(uploaded_images) + + # Update visibility of the captioning row + updates.append(gr.update(visible=visible)) + + # Update for image component - display image if available, otherwise hide + image_value = uploaded_images[i - 1] if visible else None + updates.append(gr.update(value=image_value, visible=visible)) + + corresponding_caption = False + if(image_value): + base_name = os.path.splitext(os.path.basename(image_value))[0] + print(base_name) + print(image_value) + if base_name in txt_files_dict: + print("entrou") + with open(txt_files_dict[base_name], 'r') as file: + corresponding_caption = file.read() + + # Update value of captioning area + text_value = corresponding_caption if visible and corresponding_caption else "[trigger]" if visible and concept_sentence else None + updates.append(gr.update(value=text_value, visible=visible)) + + # Update for the sample caption 
area + updates.append(gr.update(visible=True)) + # Update prompt samples + updates.append(gr.update(placeholder=f'A portrait of person in a bustling cafe {concept_sentence}', value=f'A person in a bustling cafe {concept_sentence}')) + updates.append(gr.update(placeholder=f"A mountainous landscape in the style of {concept_sentence}")) + updates.append(gr.update(placeholder=f"A {concept_sentence} in a mall")) + updates.append(gr.update(visible=True)) + return updates + +def hide_captioning(): + return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) + +def create_dataset(*inputs): + print("Creating dataset") + images = inputs[0] + destination_folder = str(f"datasets/{uuid.uuid4()}") + if not os.path.exists(destination_folder): + os.makedirs(destination_folder) + + jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl") + with open(jsonl_file_path, "a") as jsonl_file: + for index, image in enumerate(images): + new_image_path = shutil.copy(image, destination_folder) + + original_caption = inputs[index + 1] + file_name = os.path.basename(new_image_path) + + data = {"file_name": file_name, "prompt": original_caption} + + jsonl_file.write(json.dumps(data) + "\n") + + return destination_folder + + +def run_captioning(images, concept_sentence, *captions): + #Load internally to not consume resources for training + device = "cuda" if torch.cuda.is_available() else "cpu" + torch_dtype = torch.float16 + model = AutoModelForCausalLM.from_pretrained( + "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True + ).to(device) + processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True) + + captions = list(captions) + for i, image_path in enumerate(images): + print(captions[i]) + if isinstance(image_path, str): # If image is a file path + image = Image.open(image_path).convert("RGB") + + prompt = "" + inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype) + + generated_ids = model.generate( + input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3 + ) + + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] + parsed_answer = processor.post_process_generation( + generated_text, task=prompt, image_size=(image.width, image.height) + ) + caption_text = parsed_answer[""].replace("The image shows ", "") + if concept_sentence: + caption_text = f"{caption_text} [trigger]" + captions[i] = caption_text + + yield captions + model.to("cpu") + del model + del processor + +def recursive_update(d, u): + for k, v in u.items(): + if isinstance(v, dict) and v: + d[k] = recursive_update(d.get(k, {}), v) + else: + d[k] = v + return d + +def start_training( + lora_name, + concept_sentence, + steps, + lr, + rank, + model_to_train, + low_vram, + dataset_folder, + sample_1, + sample_2, + sample_3, + use_more_advanced_options, + more_advanced_options, +): + push_to_hub = True + if not lora_name: + raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.") + try: + if whoami()["auth"]["accessToken"]["role"] == "write" or "repo.write" in whoami()["auth"]["accessToken"]["fineGrained"]["scoped"][0]["permissions"]: + gr.Info(f"Starting training locally {whoami()['name']}. Your LoRA will be available locally and in Hugging Face after it finishes.") + else: + push_to_hub = False + gr.Warning("Started training locally. 
Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face") + except: + push_to_hub = False + gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face") + + print("Started training") + slugged_lora_name = slugify(lora_name) + + # Load the default config + with open("config/examples/train_lora_flux_24gb.yaml", "r") as f: + config = yaml.safe_load(f) + + # Update the config with user inputs + config["config"]["name"] = slugged_lora_name + config["config"]["process"][0]["model"]["low_vram"] = low_vram + config["config"]["process"][0]["train"]["skip_first_sample"] = True + config["config"]["process"][0]["train"]["steps"] = int(steps) + config["config"]["process"][0]["train"]["lr"] = float(lr) + config["config"]["process"][0]["network"]["linear"] = int(rank) + config["config"]["process"][0]["network"]["linear_alpha"] = int(rank) + config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder + config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub + if(push_to_hub): + try: + username = whoami()["name"] + except: + raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?") + config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}" + config["config"]["process"][0]["save"]["hf_private"] = True + if concept_sentence: + config["config"]["process"][0]["trigger_word"] = concept_sentence + + if sample_1 or sample_2 or sample_3: + config["config"]["process"][0]["train"]["disable_sampling"] = False + config["config"]["process"][0]["sample"]["sample_every"] = steps + config["config"]["process"][0]["sample"]["sample_steps"] = 28 + config["config"]["process"][0]["sample"]["prompts"] = [] + if sample_1: + config["config"]["process"][0]["sample"]["prompts"].append(sample_1) + if sample_2: + config["config"]["process"][0]["sample"]["prompts"].append(sample_2) + if sample_3: + config["config"]["process"][0]["sample"]["prompts"].append(sample_3) + else: + config["config"]["process"][0]["train"]["disable_sampling"] = True + if(model_to_train == "schnell"): + config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell" + config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter" + config["config"]["process"][0]["sample"]["sample_steps"] = 4 + if(use_more_advanced_options): + more_advanced_options_dict = yaml.safe_load(more_advanced_options) + config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict) + print(config) + + # Save the updated config + # generate a random name for the config + random_config_name = str(uuid.uuid4()) + os.makedirs("tmp", exist_ok=True) + config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + + # run the job locally + job = get_job(config_path) + job.run() + job.cleanup() + + return f"Training completed successfully. 
Model saved as {slugged_lora_name}" + +config_yaml = ''' +device: cuda:0 +model: + is_flux: true + quantize: true +network: + linear: 16 #it will overcome the 'rank' parameter + linear_alpha: 16 #you can have an alpha different than the ranking if you'd like + type: lora +sample: + guidance_scale: 3.5 + height: 1024 + neg: '' #doesn't work for FLUX + sample_every: 1000 + sample_steps: 28 + sampler: flowmatch + seed: 42 + walk_seed: true + width: 1024 +save: + dtype: float16 + hf_private: true + max_step_saves_to_keep: 4 + push_to_hub: true + save_every: 10000 +train: + batch_size: 1 + dtype: bf16 + ema_config: + ema_decay: 0.99 + use_ema: true + gradient_accumulation_steps: 1 + gradient_checkpointing: true + noise_scheduler: flowmatch + optimizer: adamw8bit #options: prodigy, dadaptation, adamw, adamw8bit, lion, lion8bit + train_text_encoder: false #probably doesn't work for flux + train_unet: true +''' + +theme = gr.themes.Monochrome( + text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"), + font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"], +) +css = """ +h1{font-size: 2em} +h3{margin-top: 0} +#component-1{text-align:center} +.main_ui_logged_out{opacity: 0.3; pointer-events: none} +.tabitem{border: 0px} +.group_padding{padding: .55em} +""" +with gr.Blocks(theme=theme, css=css) as demo: + gr.Markdown( + """# LoRA Ease for FLUX 🧞‍♂️ +### Train a high quality FLUX LoRA in a breeze ༄ using [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit)""" + ) + with gr.Column() as main_ui: + with gr.Row(): + lora_name = gr.Textbox( + label="The name of your LoRA", + info="This has to be a unique name", + placeholder="e.g.: Persian Miniature Painting style, Cat Toy", + ) + concept_sentence = gr.Textbox( + label="Trigger word/sentence", + info="Trigger word or sentence to be used", + placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'", + interactive=True, + ) + with gr.Group(visible=True) as image_upload: + with gr.Row(): + images = gr.File( + file_types=["image", ".txt"], + label="Upload your images", + file_count="multiple", + interactive=True, + visible=True, + scale=1, + ) + with gr.Column(scale=3, visible=False) as captioning_area: + with gr.Column(): + gr.Markdown( + """# Custom captioning +

You can optionally add a custom caption for each image, or generate captions with the AI model below. [trigger] will be replaced with your concept sentence/trigger word.

+""", elem_classes="group_padding") + do_captioning = gr.Button("Add AI captions with Florence-2") + output_components = [captioning_area] + caption_list = [] + for i in range(1, MAX_IMAGES + 1): + locals()[f"captioning_row_{i}"] = gr.Row(visible=False) + with locals()[f"captioning_row_{i}"]: + locals()[f"image_{i}"] = gr.Image( + type="filepath", + width=111, + height=111, + min_width=111, + interactive=False, + scale=2, + show_label=False, + show_share_button=False, + show_download_button=False, + ) + locals()[f"caption_{i}"] = gr.Textbox( + label=f"Caption {i}", scale=15, interactive=True + ) + + output_components.append(locals()[f"captioning_row_{i}"]) + output_components.append(locals()[f"image_{i}"]) + output_components.append(locals()[f"caption_{i}"]) + caption_list.append(locals()[f"caption_{i}"]) + + with gr.Accordion("Advanced options", open=False): + steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1) + lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6) + rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4) + model_to_train = gr.Radio(["dev", "schnell"], value="dev", label="Model to train") + low_vram = gr.Checkbox(label="Low VRAM", value=True) + with gr.Accordion("Even more advanced options", open=False): + use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False) + more_advanced_options = gr.Code(config_yaml, language="yaml") + + with gr.Accordion("Sample prompts (optional)", visible=False) as sample: + gr.Markdown( + "Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)" + ) + sample_1 = gr.Textbox(label="Test prompt 1") + sample_2 = gr.Textbox(label="Test prompt 2") + sample_3 = gr.Textbox(label="Test prompt 3") + + output_components.append(sample) + output_components.append(sample_1) + output_components.append(sample_2) + output_components.append(sample_3) + start = gr.Button("Start training", visible=False) + output_components.append(start) + progress_area = gr.Markdown("") + + dataset_folder = gr.State() + + images.upload( + load_captioning, + inputs=[images, concept_sentence], + outputs=output_components + ) + + images.delete( + load_captioning, + inputs=[images, concept_sentence], + outputs=output_components + ) + + images.clear( + hide_captioning, + outputs=[captioning_area, sample, start] + ) + + start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then( + fn=start_training, + inputs=[ + lora_name, + concept_sentence, + steps, + lr, + rank, + model_to_train, + low_vram, + dataset_folder, + sample_1, + sample_2, + sample_3, + use_more_advanced_options, + more_advanced_options + ], + outputs=progress_area, + ) + + do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list) + +if __name__ == "__main__": + demo.launch(share=True, show_error=True) \ No newline at end of file diff --git a/hf_ui.py b/hf_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..df41f0236932484ea32c6f8d7a4a26f324481303 --- /dev/null +++ b/hf_ui.py @@ -0,0 +1,417 @@ +import os +from huggingface_hub import whoami +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import sys +import subprocess +import gradio as gr +import uuid +import os +import shutil +import json +import yaml +from slugify import slugify +from run_modal_from_hf import main + +# Add the current working directory to the Python path +sys.path.insert(0, 
os.getcwd()) +sys.path.insert(0, "ai-toolkit") + +MAX_IMAGES = 150 + +# Import app và main trực tiếp từ run_modal_from_hf + +def create_dataset(*inputs): + print("Creating dataset") + files = inputs[0] + destination_folder = str(f"datasets/{uuid.uuid4()}") + os.makedirs(destination_folder, exist_ok=True) + + if files is not None: + # Handle both single and multiple files + if not isinstance(files, list): + files = [files] # convert to a list if is not one + + # Phân loại files + image_files = [] + caption_files = [] + zip_files = [] + + for file in files: + ext = os.path.splitext(file.name)[1].lower() + if ext in ['.jpg', '.jpeg', '.png']: + image_files.append(file) + elif ext == '.txt': + caption_files.append(file) + elif ext == '.zip': + zip_files.append(file) + else: + raise ValueError(f"Unsupported file type: {ext}") + + # Validate số lượng files + if len(zip_files) > 1: + raise ValueError("Please upload only one zip file") + + if zip_files and (image_files or caption_files): + raise ValueError("Please upload either a zip file OR individual files, not both") + + # Copy files vào destination folder + for file in image_files + caption_files: + shutil.copy2(file.name, destination_folder) + + # Nếu có zip file, chỉ copy zip file (sẽ được xử lý bên Modal) + if zip_files: + shutil.copy2(zip_files[0].name, destination_folder) + + # Validate nếu là loose files + if image_files or caption_files: + validate_image_caption_pairs(destination_folder) + + return destination_folder + +def validate_image_caption_pairs(folder_path): + """Validate images và captions nếu được upload riêng lẻ""" + images = [] + captions = [] + + for file in os.listdir(folder_path): + name, ext = os.path.splitext(file) + ext = ext.lower() + + if ext in ['.jpg', '.jpeg', '.png']: + images.append(name) + elif ext == '.txt': + captions.append(name) + + # Kiểm tra nếu có caption thì phải match với images + if captions: + missing_captions = [] + for img in images: + if img not in captions: + missing_captions.append(img) + + if missing_captions: + raise ValueError(f"Missing captions for images: {', '.join(missing_captions)}") + +def recursive_update(d, u): + for k, v in u.items(): + if isinstance(v, dict) and v: + d[k] = recursive_update(d.get(k, {}), v) + else: + d[k] = v + return d + +def start_training( + lora_name, + concept_sentence, + steps, + lr, + rank, + model_to_train, + low_vram, + dataset_folder, + sample_1, + sample_2, + sample_3, + use_more_advanced_options, + more_advanced_options, + push_to_hub, + use_wandb, +): + print("Starting training from gradio app") + + # build config + config = { + "job": "extension", + "config": { + "name": lora_name, + "process": [ + { + "type": "sd_trainer", + } + ] + }, + } + # build main config + config['config']['process'][0]['training_folder'] = "/root/ai-toolkit/modal_output" + config['config']['process'][0]['device'] = "cuda:0" + config['config']['process'][0]['network'] = { + "type": "lora", + "linear": int(rank), + "linear_alpha": int(rank) + } + config['config']['process'][0]['save'] = { + "dtype": "float16", + "save_every": int(steps), + "max_step_saves_to_keep": 4, + "push_to_hub": push_to_hub, + "hf_repo_id": f"test/{slugify(lora_name)}", + "hf_private": True + } + + config['config']['process'][0]['datasets'] = [{ + "folder_path": "/root/ai-toolkit/" + dataset_folder, # MUST match modal directory + "caption_ext": "txt", + "caption_dropout_rate": 0.05, + "shuffle_tokens": False, + "cache_latents_to_disk": True, + "resolution": [512, 768, 1024] + }] + + 
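+ # core training hyperparameters; the steps and lr values below come directly from the Gradio inputs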
config['config']['process'][0]['train'] = { + "batch_size": 1, + "steps": int(steps), + "gradient_accumulation_steps": 1, + "train_unet": True, + "train_text_encoder": False, + "gradient_checkpointing": True, + "noise_scheduler": "flowmatch", + "optimizer": "adamw8bit", + "lr": float(lr), + "dtype": "bf16", + "ema_config": { + "use_ema": True, + "ema_decay": 0.99 + } + } + + config['config']['process'][0]['model'] = { + "name_or_path": "black-forest-labs/FLUX.1-dev", + "is_flux": True, + "quantize": True, + "low_vram": low_vram + } + config['config']['process'][0]['sample'] = { + "sampler": "flowmatch", + "sample_every": int(steps), + "width": 1024, + "height": 1024, + "prompts": [ + f"woman with red hair, playing chess at the park, bomb going off in the background {concept_sentence}", + f"a woman holding a coffee cup, in a beanie, sitting at a cafe {concept_sentence}", + f"a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini {concept_sentence}", + f"a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background {concept_sentence}", + f"a bear building a log cabin in the snow covered mountains {concept_sentence}", + f"woman playing the guitar, on stage, singing a song, laser lights, punk rocker {concept_sentence}", + f"hipster man with a beard, building a chair, in a wood shop {concept_sentence}", + f"photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop {concept_sentence}", + f"a man holding a sign that says, 'this is a sign' {concept_sentence}", + f"a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle {concept_sentence}" + ], + "neg": "", + "seed": 42, + "walk_seed": True, + "guidance_scale": 4, + "sample_steps": 20 + } + if sample_1 or sample_2 or sample_3: + config['config']['process'][0]["sample"]['prompts'] = [] + if sample_1: + config["config"]["process"][0]["sample"]["prompts"].append(sample_1) + if sample_2: + config["config"]["process"][0]["sample"]["prompts"].append(sample_2) + if sample_3: + config["config"]["process"][0]["sample"]["prompts"].append(sample_3) + + if concept_sentence: + config['config']['process'][0]['trigger_word'] = concept_sentence + + if(model_to_train == "schnell"): + config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell" + config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter" + config["config"]["process"][0]["sample"]["sample_steps"] = 4 + config["config"]["process"][0]["sample"]["guidance_scale"] = 1 + + if(use_more_advanced_options): + more_advanced_options_dict = yaml.safe_load(more_advanced_options) + config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict) + + # add wandb if needed + config['config']['process'][0]['logging'] = { + "log_every": 10, + "use_wandb": use_wandb, + "verbose": False + } + + # pass to modal function + config_file_list_str = json.dumps(config) + + try: + main.remote( + config_file_list_str=config_file_list_str, + recover=True, + name=lora_name + ) + return "Training started in Modal. 
Check your Modal dashboard for logs and status" + except Exception as e: + return f"Error starting training: {str(e)}" + +def setup_modal_token(token_command): + try: + # Tách command thành các phần + parts = token_command.strip().split() + if len(parts) == 6 and parts[0] == "modal" and parts[1] == "token" and parts[2] == "set": + token_id = parts[4] + token_secret = parts[6] + + # Thực thi lệnh + result = subprocess.run( + ["modal", "token", "set", "--token-id", token_id, "--token-secret", token_secret], + capture_output=True, + text=True + ) + + if result.returncode == 0: + return "Modal token đã được cấu hình thành công!" + else: + return f"Lỗi khi cấu hình token: {result.stderr}" + except Exception as e: + return f"Lỗi: {str(e)}" + +config_yaml = ''' +device: cuda:0 +model: + is_flux: true + quantize: true +network: + linear: 16 #it will overcome the 'rank' parameter + linear_alpha: 16 #you can have an alpha different than the ranking if you'd like + type: lora +sample: + guidance_scale: 3.5 + height: 1024 + neg: '' #doesn't work for FLUX + sample_every: 1000 + sample_steps: 28 + sampler: flowmatch + seed: 42 + walk_seed: true + width: 1024 +save: + dtype: float16 + hf_private: true + max_step_saves_to_keep: 4 + push_to_hub: true + save_every: 10000 +train: + batch_size: 1 + dtype: bf16 + ema_config: + ema_decay: 0.99 + use_ema: true + gradient_accumulation_steps: 1 + gradient_checkpointing: true + noise_scheduler: flowmatch + optimizer: adamw8bit + train_text_encoder: false #probably doesn't work for flux + train_unet: true +''' + +theme = gr.themes.Monochrome( + text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"), + font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"], +) +css = """ +h1{font-size: 2em} +h3{margin-top: 0} +#component-1{text-align:center} +.main_ui_logged_out{opacity: 0.3; pointer-events: none} +.tabitem{border: 0px} +.group_padding{padding: .55em} +""" +with gr.Blocks(theme=theme, css=css) as demo: + gr.Markdown( + """# LoRA Ease for FLUX 🧞‍♂️ +### Train a high quality FLUX LoRA in a breeze ༄ using [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit)""" + ) + with gr.Column() as main_ui: + with gr.Row(): + lora_name = gr.Textbox( + label="The name of your LoRA", + info="This has to be a unique name", + placeholder="e.g.: Persian Miniature Painting style, Cat Toy", + ) + concept_sentence = gr.Textbox( + label="Trigger word/sentence", + info="Trigger word or sentence to be used", + placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'", + interactive=True, + ) + with gr.Group(visible=True) as image_upload: + with gr.Row(): + images = gr.File( + file_types=["image", ".txt", ".zip"], + label="Upload your dataset as zip or multiple images", + file_count="multiple", + interactive=True, + visible=True, + scale=1, + ) + + with gr.Accordion("Advanced options", open=False): + steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1) + lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6) + rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4) + model_to_train = gr.Radio(["dev", "schnell"], value="dev", label="Model to train") + low_vram = gr.Checkbox(label="Low VRAM", value=True) + with gr.Accordion("Even more advanced options", open=False): + use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False) + more_advanced_options = 
gr.Code(config_yaml, language="yaml") + + with gr.Accordion("Sample prompts (optional)", visible=False) as sample: + gr.Markdown( + "Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)" + ) + sample_1 = gr.Textbox(label="Test prompt 1") + sample_2 = gr.Textbox(label="Test prompt 2") + sample_3 = gr.Textbox(label="Test prompt 3") + + output_components = [sample, sample_1, sample_2, sample_3] + + with gr.Row(): + push_to_hub = gr.Checkbox(label="Push to Hub", value=True) + use_wandb = gr.Checkbox(label="Use WandB", value=False) + start = gr.Button("Start training") + output_components.append(start) + progress_area = gr.Markdown("") + output_components.append(progress_area) + + dataset_folder = gr.State() + + with gr.Accordion("Modal Configuration", open=False): + modal_token_input = gr.Textbox( + label="Nhập lệnh Modal token", + placeholder="modal token set --token-id YOUR_TOKEN_ID --token-secret YOUR_TOKEN_SECRET" + ) + modal_setup_btn = gr.Button("Setup Modal Token") + modal_status = gr.Markdown("") + + modal_setup_btn.click( + fn=setup_modal_token, + inputs=[modal_token_input], + outputs=[modal_status] + ) + + start.click(fn=create_dataset, inputs=[images], outputs=dataset_folder).then( + fn=start_training, + inputs=[ + lora_name, + concept_sentence, + steps, + lr, + rank, + model_to_train, + low_vram, + dataset_folder, + sample_1, + sample_2, + sample_3, + use_more_advanced_options, + more_advanced_options, + push_to_hub, + use_wandb + ], + outputs=progress_area, + ) + +if __name__ == "__main__": + demo.launch(share=True, show_error=True) diff --git a/info.py b/info.py new file mode 100644 index 0000000000000000000000000000000000000000..9f2f0a97403deb778f0549c5fed2f9972ac75209 --- /dev/null +++ b/info.py @@ -0,0 +1,8 @@ +from collections import OrderedDict + +v = OrderedDict() +v["name"] = "ai-toolkit" +v["repo"] = "https://github.com/ostris/ai-toolkit" +v["version"] = "0.1.0" + +software_meta = v diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py new file mode 100644 index 0000000000000000000000000000000000000000..8efd0097c6898cd8a6087fe9299f7e191f5a893a --- /dev/null +++ b/jobs/BaseJob.py @@ -0,0 +1,72 @@ +import importlib +from collections import OrderedDict +from typing import List + +from jobs.process import BaseProcess + + +class BaseJob: + + def __init__(self, config: OrderedDict): + if not config: + raise ValueError('config is required') + self.process: List[BaseProcess] + + self.config = config['config'] + self.raw_config = config + self.job = config['job'] + self.torch_profiler = self.get_conf('torch_profiler', False) + self.name = self.get_conf('name', required=True) + if 'meta' in config: + self.meta = config['meta'] + else: + self.meta = OrderedDict() + + def get_conf(self, key, default=None, required=False): + if key in self.config: + return self.config[key] + elif required: + raise ValueError(f'config file error. Missing "config.{key}" key') + else: + return default + + def run(self): + print("") + print(f"#############################################") + print(f"# Running job: {self.name}") + print(f"#############################################") + print("") + # implement in child class + # be sure to call super().run() first + pass + + def load_processes(self, process_dict: dict): + # only call if you have processes in this job type + if 'process' not in self.config: + raise ValueError('config file is invalid. 
Missing "config.process" key') + if len(self.config['process']) == 0: + raise ValueError('config file is invalid. "config.process" must be a list of processes') + + module = importlib.import_module('jobs.process') + + # add the processes + self.process = [] + for i, process in enumerate(self.config['process']): + if 'type' not in process: + raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key') + + # check if dict key is process type + if process['type'] in process_dict: + if isinstance(process_dict[process['type']], str): + ProcessClass = getattr(module, process_dict[process['type']]) + else: + # it is the class + ProcessClass = process_dict[process['type']] + self.process.append(ProcessClass(i, self, process)) + else: + raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}') + + def cleanup(self): + # if you implement this in child clas, + # be sure to call super().cleanup() LAST + del self diff --git a/jobs/ExtensionJob.py b/jobs/ExtensionJob.py new file mode 100644 index 0000000000000000000000000000000000000000..def4f8530a8a92c65369cd63a3e69c16bf0bb7de --- /dev/null +++ b/jobs/ExtensionJob.py @@ -0,0 +1,22 @@ +import os +from collections import OrderedDict +from jobs import BaseJob +from toolkit.extension import get_all_extensions_process_dict +from toolkit.paths import CONFIG_ROOT + +class ExtensionJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.device = self.get_conf('device', 'cpu') + self.process_dict = get_all_extensions_process_dict() + self.load_processes(self.process_dict) + + def run(self): + super().run() + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/ExtractJob.py b/jobs/ExtractJob.py new file mode 100644 index 0000000000000000000000000000000000000000..d710d4128db5304569357ee05d2fb31fa15c6e39 --- /dev/null +++ b/jobs/ExtractJob.py @@ -0,0 +1,58 @@ +from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint +from collections import OrderedDict +from jobs import BaseJob +from toolkit.train_tools import get_torch_dtype + +process_dict = { + 'locon': 'ExtractLoconProcess', + 'lora': 'ExtractLoraProcess', +} + + +class ExtractJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.base_model_path = self.get_conf('base_model', required=True) + self.model_base = None + self.model_base_text_encoder = None + self.model_base_vae = None + self.model_base_unet = None + self.extract_model_path = self.get_conf('extract_model', required=True) + self.model_extract = None + self.model_extract_text_encoder = None + self.model_extract_vae = None + self.model_extract_unet = None + self.extract_unet = self.get_conf('extract_unet', True) + self.extract_text_encoder = self.get_conf('extract_text_encoder', True) + self.dtype = self.get_conf('dtype', 'fp16') + self.torch_dtype = get_torch_dtype(self.dtype) + self.output_folder = self.get_conf('output_folder', required=True) + self.is_v2 = self.get_conf('is_v2', False) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + # load models + print(f"Loading models for extraction") + print(f" - Loading base model: {self.base_model_path}") + # (text_model, vae, unet) + self.model_base = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) + 
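+ # unpack the (text_encoder, vae, unet) tuple returned for the base model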
self.model_base_text_encoder = self.model_base[0] + self.model_base_vae = self.model_base[1] + self.model_base_unet = self.model_base[2] + + print(f" - Loading extract model: {self.extract_model_path}") + self.model_extract = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path) + self.model_extract_text_encoder = self.model_extract[0] + self.model_extract_vae = self.model_extract[1] + self.model_extract_unet = self.model_extract[2] + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/GenerateJob.py b/jobs/GenerateJob.py new file mode 100644 index 0000000000000000000000000000000000000000..ab61701a3bc5e3a0f21e67c5aeee7c67b5e8f7c2 --- /dev/null +++ b/jobs/GenerateJob.py @@ -0,0 +1,31 @@ +from jobs import BaseJob +from collections import OrderedDict +from typing import List +from jobs.process import GenerateProcess +from toolkit.paths import REPOS_ROOT + +import sys + +sys.path.append(REPOS_ROOT) + +process_dict = { + 'to_folder': 'GenerateProcess', +} + + +class GenerateJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/MergeJob.py b/jobs/MergeJob.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e3b87b9ff589438d06c56019446f06efb76cda --- /dev/null +++ b/jobs/MergeJob.py @@ -0,0 +1,29 @@ +from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint +from collections import OrderedDict +from jobs import BaseJob +from toolkit.train_tools import get_torch_dtype + +process_dict = { +} + + +class MergeJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.dtype = self.get_conf('dtype', 'fp16') + self.torch_dtype = get_torch_dtype(self.dtype) + self.is_v2 = self.get_conf('is_v2', False) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/ModJob.py b/jobs/ModJob.py new file mode 100644 index 0000000000000000000000000000000000000000..e37990de95a0d2ad78a94f9cdfd6dfbda0cdc529 --- /dev/null +++ b/jobs/ModJob.py @@ -0,0 +1,28 @@ +import os +from collections import OrderedDict +from jobs import BaseJob +from toolkit.metadata import get_meta_for_safetensors +from toolkit.train_tools import get_torch_dtype + +process_dict = { + 'rescale_lora': 'ModRescaleLoraProcess', +} + + +class ModJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py new file mode 100644 index 0000000000000000000000000000000000000000..dda64e2d94171cf53dd02b279b0e0456dc013e09 --- /dev/null +++ 
b/jobs/TrainJob.py @@ -0,0 +1,49 @@ +import json +import os + +from jobs import BaseJob +from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint +from collections import OrderedDict +from typing import List +from jobs.process import BaseExtractProcess, TrainFineTuneProcess +from datetime import datetime +import yaml +from toolkit.paths import REPOS_ROOT + +import sys + +sys.path.append(REPOS_ROOT) + +process_dict = { + 'vae': 'TrainVAEProcess', + 'slider': 'TrainSliderProcess', + 'slider_old': 'TrainSliderProcessOld', + 'lora_hack': 'TrainLoRAHack', + 'rescale_sd': 'TrainSDRescaleProcess', + 'esrgan': 'TrainESRGANProcess', + 'reference': 'TrainReferenceProcess', +} + + +class TrainJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.training_folder = self.get_conf('training_folder', required=True) + self.is_v2 = self.get_conf('is_v2', False) + self.device = self.get_conf('device', 'cpu') + # self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1) + # self.mixed_precision = self.get_conf('mixed_precision', False) # fp16 + self.log_dir = self.get_conf('log_dir', None) + + # loads the processes from the config + self.load_processes(process_dict) + + + def run(self): + super().run() + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/__init__.py b/jobs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7da6c22b1ddd0ea9248e5afbf9b2ba014c137c1a --- /dev/null +++ b/jobs/__init__.py @@ -0,0 +1,7 @@ +from .BaseJob import BaseJob +from .ExtractJob import ExtractJob +from .TrainJob import TrainJob +from .MergeJob import MergeJob +from .ModJob import ModJob +from .GenerateJob import GenerateJob +from .ExtensionJob import ExtensionJob diff --git a/jobs/process/BaseExtensionProcess.py b/jobs/process/BaseExtensionProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..b53dc1c498e64bb4adbc2b967b329fdc4a374925 --- /dev/null +++ b/jobs/process/BaseExtensionProcess.py @@ -0,0 +1,19 @@ +from collections import OrderedDict +from typing import ForwardRef +from jobs.process.BaseProcess import BaseProcess + + +class BaseExtensionProcess(BaseProcess): + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.process_id: int + self.config: OrderedDict + self.progress_bar: ForwardRef('tqdm') = None + + def run(self): + super().run() diff --git a/jobs/process/BaseExtractProcess.py b/jobs/process/BaseExtractProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..ac10da54d82f15c8264b2799b10b01bb5cf8dc66 --- /dev/null +++ b/jobs/process/BaseExtractProcess.py @@ -0,0 +1,86 @@ +import os +from collections import OrderedDict + +from safetensors.torch import save_file + +from jobs.process.BaseProcess import BaseProcess +from toolkit.metadata import get_meta_for_safetensors + +from typing import ForwardRef + +from toolkit.train_tools import get_torch_dtype + + +class BaseExtractProcess(BaseProcess): + + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.config: OrderedDict + self.output_folder: str + self.output_filename: str + self.output_path: str + self.process_id = process_id + self.job = job + self.config = config + self.dtype = self.get_conf('dtype', self.job.dtype) + self.torch_dtype = 
get_torch_dtype(self.dtype) + self.extract_unet = self.get_conf('extract_unet', self.job.extract_unet) + self.extract_text_encoder = self.get_conf('extract_text_encoder', self.job.extract_text_encoder) + + def run(self): + # here instead of init because child init needs to go first + self.output_path = self.get_output_path() + # implement in child class + # be sure to call super().run() first + pass + + # you can override this in the child class if you want + # call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this + def get_output_path(self, prefix=None, suffix=None): + config_output_path = self.get_conf('output_path', None) + config_filename = self.get_conf('filename', None) + # replace [name] with name + + if config_output_path is not None: + config_output_path = config_output_path.replace('[name]', self.job.name) + return config_output_path + + if config_output_path is None and config_filename is not None: + # build the output path from the output folder and filename + return os.path.join(self.job.output_folder, config_filename) + + # build our own + + if suffix is None: + # we will just add process it to the end of the filename if there is more than one process + # and no other suffix was given + suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else '' + + if prefix is None: + prefix = '' + + output_filename = f"{prefix}{self.output_filename}{suffix}" + + return os.path.join(self.job.output_folder, output_filename) + + def save(self, state_dict): + # prepare meta + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + + # save + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(self.torch_dtype) + state_dict[key] = v + + # having issues with meta + save_file(state_dict, self.output_path, save_meta) + + print(f"Saved to {self.output_path}") diff --git a/jobs/process/BaseMergeProcess.py b/jobs/process/BaseMergeProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..55dfec68ae62383afae539ff6cb51862033a7e10 --- /dev/null +++ b/jobs/process/BaseMergeProcess.py @@ -0,0 +1,46 @@ +import os +from collections import OrderedDict + +from safetensors.torch import save_file + +from jobs.process.BaseProcess import BaseProcess +from toolkit.metadata import get_meta_for_safetensors +from toolkit.train_tools import get_torch_dtype + + +class BaseMergeProcess(BaseProcess): + + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.process_id: int + self.config: OrderedDict + self.output_path = self.get_conf('output_path', required=True) + self.dtype = self.get_conf('dtype', self.job.dtype) + self.torch_dtype = get_torch_dtype(self.dtype) + + def run(self): + # implement in child class + # be sure to call super().run() first + pass + + def save(self, state_dict): + # prepare meta + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + + # save + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(self.torch_dtype) + state_dict[key] = v + + # having issues with meta + save_file(state_dict, self.output_path, save_meta) + + print(f"Saved to {self.output_path}") diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py new file mode 100644 index 
0000000000000000000000000000000000000000..f064460750e6bf845e3ecf4fa9ebf476d47eb162 --- /dev/null +++ b/jobs/process/BaseProcess.py @@ -0,0 +1,61 @@ +import copy +import json +from collections import OrderedDict + +from toolkit.timer import Timer + + +class BaseProcess(object): + + def __init__( + self, + process_id: int, + job: 'BaseJob', + config: OrderedDict + ): + self.process_id = process_id + self.meta: OrderedDict + self.job = job + self.config = config + self.raw_process_config = config + self.name = self.get_conf('name', self.job.name) + self.meta = copy.deepcopy(self.job.meta) + self.timer: Timer = Timer(f'{self.name} Timer') + self.performance_log_every = self.get_conf('performance_log_every', 0) + + print(json.dumps(self.config, indent=4)) + + def get_conf(self, key, default=None, required=False, as_type=None): + # split key by '.' and recursively get the value + keys = key.split('.') + + # see if it exists in the config + value = self.config + for subkey in keys: + if subkey in value: + value = value[subkey] + else: + value = None + break + + if value is not None: + if as_type is not None: + value = as_type(value) + return value + elif required: + raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key') + else: + if as_type is not None and default is not None: + return as_type(default) + return default + + def run(self): + # implement in child class + # be sure to call super().run() first incase something is added here + pass + + def add_meta(self, additional_meta: OrderedDict): + self.meta.update(additional_meta) + + +from jobs import BaseJob diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..87984717bd648f9b1b0d21c3c73b3e870f908ad8 --- /dev/null +++ b/jobs/process/BaseSDTrainProcess.py @@ -0,0 +1,2105 @@ +import copy +import glob +import inspect +import json +import random +import shutil +from collections import OrderedDict +import os +import re +from typing import Union, List, Optional + +import numpy as np +import yaml +from diffusers import T2IAdapter, ControlNetModel +from diffusers.training_utils import compute_density_for_timestep_sampling +from safetensors.torch import save_file, load_file +# from lycoris.config import PRESET +from torch.utils.data import DataLoader +import torch +import torch.backends.cuda +from huggingface_hub import HfApi, Repository, interpreter_login +from huggingface_hub.utils import HfFolder + +from toolkit.basic import value_map +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.ema import ExponentialMovingAverage +from toolkit.embedding import Embedding +from toolkit.image_utils import show_tensors, show_latents, reduce_contrast +from toolkit.ip_adapter import IPAdapter +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \ + lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE +from toolkit.lycoris_special import LycorisSpecialNetwork +from toolkit.models.decorator import Decorator +from toolkit.network_mixins import Network +from toolkit.optimizer import get_optimizer +from toolkit.paths import CONFIG_ROOT +from 
toolkit.progress_bar import ToolkitProgressBar +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.sampler import get_sampler +from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \ + load_ip_adapter_model, load_custom_adapter_model + +from toolkit.scheduler import get_lr_scheduler +from toolkit.sd_device_states_presets import get_train_sd_device_state_preset +from toolkit.stable_diffusion_model import StableDiffusion + +from jobs.process import BaseTrainProcess +from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \ + parse_metadata_from_safetensors +from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma, apply_learnable_snr_gos, apply_snr_weight +import gc + +from tqdm import tqdm + +from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ + GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \ + DecoratorConfig +from toolkit.logging import create_logger +from diffusers import FluxTransformer2DModel + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class BaseSDTrainProcess(BaseTrainProcess): + + def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None): + super().__init__(process_id, job, config) + self.sd: StableDiffusion + self.embedding: Union[Embedding, None] = None + + self.custom_pipeline = custom_pipeline + self.step_num = 0 + self.start_step = 0 + self.epoch_num = 0 + # start at 1 so we can do a sample at the start + self.grad_accumulation_step = 1 + # if true, then we do not do an optimizer step. We are accumulating gradients + self.is_grad_accumulation_step = False + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + network_config = self.get_conf('network', None) + if network_config is not None: + self.network_config = NetworkConfig(**network_config) + else: + self.network_config = None + self.train_config = TrainConfig(**self.get_conf('train', {})) + model_config = self.get_conf('model', {}) + + # update modelconfig dtype to match train + model_config['dtype'] = self.train_config.dtype + self.model_config = ModelConfig(**model_config) + + self.save_config = SaveConfig(**self.get_conf('save', {})) + self.sample_config = SampleConfig(**self.get_conf('sample', {})) + first_sample_config = self.get_conf('first_sample', None) + if first_sample_config is not None: + self.has_first_sample_requested = True + self.first_sample_config = SampleConfig(**first_sample_config) + else: + self.has_first_sample_requested = False + self.first_sample_config = self.sample_config + self.logging_config = LoggingConfig(**self.get_conf('logging', {})) + self.logger = create_logger(self.logging_config, config) + self.optimizer: torch.optim.Optimizer = None + self.lr_scheduler = None + self.data_loader: Union[DataLoader, None] = None + self.data_loader_reg: Union[DataLoader, None] = None + self.trigger_word = self.get_conf('trigger_word', None) + + self.guidance_config: Union[GuidanceConfig, None] = None + guidance_config_raw = self.get_conf('guidance', None) + if guidance_config_raw is not None: + self.guidance_config = GuidanceConfig(**guidance_config_raw) + + # store is all are cached. 
Allows us to not load vae if we don't need to + self.is_latents_cached = True + raw_datasets = self.get_conf('datasets', None) + if raw_datasets is not None and len(raw_datasets) > 0: + raw_datasets = preprocess_dataset_raw_config(raw_datasets) + self.datasets = None + self.datasets_reg = None + self.params = [] + if raw_datasets is not None and len(raw_datasets) > 0: + for raw_dataset in raw_datasets: + dataset = DatasetConfig(**raw_dataset) + is_caching = dataset.cache_latents or dataset.cache_latents_to_disk + if not is_caching: + self.is_latents_cached = False + if dataset.is_reg: + if self.datasets_reg is None: + self.datasets_reg = [] + self.datasets_reg.append(dataset) + else: + if self.datasets is None: + self.datasets = [] + self.datasets.append(dataset) + + self.embed_config = None + embedding_raw = self.get_conf('embedding', None) + if embedding_raw is not None: + self.embed_config = EmbeddingConfig(**embedding_raw) + + self.decorator_config: DecoratorConfig = None + decorator_raw = self.get_conf('decorator', None) + if decorator_raw is not None: + if not self.model_config.is_flux: + raise ValueError("Decorators are only supported for Flux models currently") + self.decorator_config = DecoratorConfig(**decorator_raw) + + # t2i adapter + self.adapter_config = None + adapter_raw = self.get_conf('adapter', None) + if adapter_raw is not None: + self.adapter_config = AdapterConfig(**adapter_raw) + # sdxl adapters end in _xl. Only full_adapter_xl for now + if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'): + self.adapter_config.adapter_type += '_xl' + + # to hold network if there is one + self.network: Union[Network, None] = None + self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None + self.embedding: Union[Embedding, None] = None + self.decorator: Union[Decorator, None] = None + + is_training_adapter = self.adapter_config is not None and self.adapter_config.train + + self.do_lorm = self.get_conf('do_lorm', False) + self.lorm_extract_mode = self.get_conf('lorm_extract_mode', 'ratio') + self.lorm_extract_mode_param = self.get_conf('lorm_extract_mode_param', 0.25) + # 'ratio', 0.25) + + # get the device state preset based on what we are training + self.train_device_state_preset = get_train_sd_device_state_preset( + device=self.device_torch, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + cached_latents=self.is_latents_cached, + train_lora=self.network_config is not None, + train_adapter=is_training_adapter, + train_embedding=self.embed_config is not None, + train_decorator=self.decorator_config is not None, + train_refiner=self.train_config.train_refiner, + unload_text_encoder=self.train_config.unload_text_encoder, + require_grads=False # we ensure them later + ) + + self.get_params_device_state_preset = get_train_sd_device_state_preset( + device=self.device_torch, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + cached_latents=self.is_latents_cached, + train_lora=self.network_config is not None, + train_adapter=is_training_adapter, + train_embedding=self.embed_config is not None, + train_decorator=self.decorator_config is not None, + train_refiner=self.train_config.train_refiner, + unload_text_encoder=self.train_config.unload_text_encoder, + require_grads=True # We check for grads when getting params + ) + + # fine_tuning here is for training actual SD network, not LoRA, 
embeddings, etc. it is (Dreambooth, etc) + self.is_fine_tuning = True + if self.network_config is not None or is_training_adapter or self.embed_config is not None or self.decorator_config is not None: + self.is_fine_tuning = False + + self.named_lora = False + if self.embed_config is not None or is_training_adapter: + self.named_lora = True + self.snr_gos: Union[LearnableSNRGamma, None] = None + self.ema: ExponentialMovingAverage = None + + validate_configs(self.train_config, self.model_config, self.save_config) + + def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]): + # override in subclass + return generate_image_config_list + + def sample(self, step=None, is_first=False): + flush() + sample_folder = os.path.join(self.save_root, 'samples') + gen_img_config_list = [] + + sample_config = self.first_sample_config if is_first else self.sample_config + start_seed = sample_config.seed + current_seed = start_seed + + test_image_paths = [] + if self.adapter_config is not None and self.adapter_config.test_img_path is not None: + test_image_path_list = self.adapter_config.test_img_path.split(',') + test_image_path_list = [p.strip() for p in test_image_path_list] + test_image_path_list = [p for p in test_image_path_list if p != ''] + # divide up images so they are evenly distributed across prompts + for i in range(len(sample_config.prompts)): + test_image_paths.append(test_image_path_list[i % len(test_image_path_list)]) + + for i in range(len(sample_config.prompts)): + if sample_config.walk_seed: + current_seed = start_seed + i + + step_num = '' + if step is not None: + # zero-pad 9 digits + step_num = f"_{str(step).zfill(9)}" + + filename = f"[time]_{step_num}_[count].{self.sample_config.ext}" + + output_path = os.path.join(sample_folder, filename) + + prompt = sample_config.prompts[i] + + # add embedding if there is one + # note: diffusers will automatically expand the trigger to the number of added tokens + # ie test123 will become test123 test123_1 test123_2 etc. 
Do not add this yourself here + if self.embedding is not None: + prompt = self.embedding.inject_embedding_to_prompt( + prompt, expand_token=True, add_if_not_present=False + ) + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): + prompt = self.adapter.inject_trigger_into_prompt( + prompt, expand_token=True, add_if_not_present=False + ) + if self.trigger_word is not None: + prompt = self.sd.inject_trigger_into_prompt( + prompt, self.trigger_word, add_if_not_present=False + ) + + extra_args = {} + if self.adapter_config is not None and self.adapter_config.test_img_path is not None: + extra_args['adapter_image_path'] = test_image_paths[i] + + gen_img_config_list.append(GenerateImageConfig( + prompt=prompt, # it will autoparse the prompt + width=sample_config.width, + height=sample_config.height, + negative_prompt=sample_config.neg, + seed=current_seed, + guidance_scale=sample_config.guidance_scale, + guidance_rescale=sample_config.guidance_rescale, + num_inference_steps=sample_config.sample_steps, + network_multiplier=sample_config.network_multiplier, + output_path=output_path, + output_ext=sample_config.ext, + adapter_conditioning_scale=sample_config.adapter_conditioning_scale, + refiner_start_at=sample_config.refiner_start_at, + extra_values=sample_config.extra_values, + logger=self.logger, + **extra_args + )) + + # post process + gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list) + + # if we have an ema, set it to validation mode + if self.ema is not None: + self.ema.eval() + + # send to be generated + self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler) + + if self.ema is not None: + self.ema.train() + + def update_training_metadata(self): + o_dict = OrderedDict({ + "training_info": self.get_training_info() + }) + if self.model_config.is_v2: + o_dict['ss_v2'] = True + o_dict['ss_base_model_version'] = 'sd_2.1' + + elif self.model_config.is_xl: + o_dict['ss_base_model_version'] = 'sdxl_1.0' + else: + o_dict['ss_base_model_version'] = 'sd_1.5' + + o_dict = add_base_model_info_to_meta( + o_dict, + is_v2=self.model_config.is_v2, + is_xl=self.model_config.is_xl, + ) + o_dict['ss_output_name'] = self.job.name + + if self.trigger_word is not None: + # just so auto1111 will pick it up + o_dict['ss_tag_frequency'] = { + f"1_{self.trigger_word}": { + f"{self.trigger_word}": 1 + } + } + + self.add_meta(o_dict) + + def get_training_info(self): + info = OrderedDict({ + 'step': self.step_num, + 'epoch': self.epoch_num, + }) + return info + + def clean_up_saves(self): + # remove old saves + # get latest saved step + latest_item = None + if os.path.exists(self.save_root): + # pattern is {job_name}_{zero_filled_step} for both files and directories + pattern = f"{self.job.name}_*" + items = glob.glob(os.path.join(self.save_root, pattern)) + # Separate files and directories + safetensors_files = [f for f in items if f.endswith('.safetensors')] + pt_files = [f for f in items if f.endswith('.pt')] + directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')] + embed_files = [] + # do embedding files + if self.embed_config is not None: + embed_pattern = f"{self.embed_config.trigger}_*" + embed_items = glob.glob(os.path.join(self.save_root, embed_pattern)) + # will end in safetensors or pt + embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')] + + # check for critic files + critic_pattern = f"CRITIC_{self.job.name}_*" + critic_items = 
glob.glob(os.path.join(self.save_root, critic_pattern)) + + # Sort the lists by creation time if they are not empty + if safetensors_files: + safetensors_files.sort(key=os.path.getctime) + if pt_files: + pt_files.sort(key=os.path.getctime) + if directories: + directories.sort(key=os.path.getctime) + if embed_files: + embed_files.sort(key=os.path.getctime) + if critic_items: + critic_items.sort(key=os.path.getctime) + + # Combine and sort the lists + combined_items = safetensors_files + directories + pt_files + combined_items.sort(key=os.path.getctime) + + # Use slicing with a check to avoid 'NoneType' error + safetensors_to_remove = safetensors_files[ + :-self.save_config.max_step_saves_to_keep] if safetensors_files else [] + pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else [] + directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else [] + embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else [] + critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else [] + + items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove + + # remove all but the latest max_step_saves_to_keep + # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep] + + # remove duplicates + items_to_remove = list(dict.fromkeys(items_to_remove)) + + for item in items_to_remove: + self.print(f"Removing old save: {item}") + if os.path.isdir(item): + shutil.rmtree(item) + else: + os.remove(item) + # see if a yaml file with same name exists + yaml_file = os.path.splitext(item)[0] + ".yaml" + if os.path.exists(yaml_file): + os.remove(yaml_file) + if combined_items: + latest_item = combined_items[-1] + return latest_item + + def post_save_hook(self, save_path): + # override in subclass + pass + + def save(self, step=None): + flush() + if self.ema is not None: + # always save params as ema + self.ema.eval() + + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + + self.update_training_metadata() + filename = f'{self.job.name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + + save_meta = copy.deepcopy(self.meta) + # get extra meta + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + additional_save_meta = self.adapter.get_additional_save_metadata() + if additional_save_meta is not None: + for key, value in additional_save_meta.items(): + save_meta[key] = value + + # prepare meta + save_meta = get_meta_for_safetensors(save_meta, self.job.name) + if not self.is_fine_tuning: + if self.network is not None: + lora_name = self.job.name + if self.named_lora: + # add _lora to name + lora_name += '_LoRA' + + filename = f'{lora_name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + prev_multiplier = self.network.multiplier + self.network.multiplier = 1.0 + + # if we are doing embedding training as well, add that + embedding_dict = self.embedding.state_dict() if self.embedding else None + self.network.save_weights( + file_path, + dtype=get_torch_dtype(self.save_config.dtype), + metadata=save_meta, + extra_state_dict=embedding_dict + ) + self.network.multiplier = prev_multiplier + # if we have an embedding as well, pair it with the network + + # even if added to lora, still save the 
trigger version + if self.embedding is not None: + emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors' + emb_file_path = os.path.join(self.save_root, emb_filename) + # for combo, above will get it + # set current step + self.embedding.step = self.step_num + # change filename to pt if that is set + if self.embed_config.save_format == "pt": + # replace extension + emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt" + self.embedding.save(emb_file_path) + + if self.decorator is not None: + dec_filename = f'{self.job.name}{step_num}.safetensors' + dec_file_path = os.path.join(self.save_root, dec_filename) + decorator_state_dict = self.decorator.state_dict() + for key, value in decorator_state_dict.items(): + if isinstance(value, torch.Tensor): + decorator_state_dict[key] = value.clone().to('cpu', dtype=get_torch_dtype(self.save_config.dtype)) + save_file( + decorator_state_dict, + dec_file_path, + metadata=save_meta, + ) + + if self.adapter is not None and self.adapter_config.train: + adapter_name = self.job.name + if self.network_config is not None or self.embedding is not None: + # add _lora to name + if self.adapter_config.type == 't2i': + adapter_name += '_t2i' + elif self.adapter_config.type == 'control_net': + adapter_name += '_cn' + elif self.adapter_config.type == 'clip': + adapter_name += '_clip' + elif self.adapter_config.type.startswith('ip'): + adapter_name += '_ip' + else: + adapter_name += '_adapter' + + filename = f'{adapter_name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + # save adapter + state_dict = self.adapter.state_dict() + if self.adapter_config.type == 't2i': + save_t2i_from_diffusers( + state_dict, + output_file=file_path, + meta=save_meta, + dtype=get_torch_dtype(self.save_config.dtype) + ) + elif self.adapter_config.type == 'control_net': + # save in diffusers format + name_or_path = file_path.replace('.safetensors', '') + # move it to the new dtype and cpu + orig_device = self.adapter.device + orig_dtype = self.adapter.dtype + self.adapter = self.adapter.to(torch.device('cpu'), dtype=get_torch_dtype(self.save_config.dtype)) + self.adapter.save_pretrained( + name_or_path, + dtype=get_torch_dtype(self.save_config.dtype), + safe_serialization=True + ) + meta_path = os.path.join(name_or_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(self.meta, f) + # move it back + self.adapter = self.adapter.to(orig_device, dtype=orig_dtype) + else: + direct_save = False + if self.adapter_config.train_only_image_encoder: + direct_save = True + if self.adapter_config.type == 'redux': + direct_save = True + save_ip_adapter_from_diffusers( + state_dict, + output_file=file_path, + meta=save_meta, + dtype=get_torch_dtype(self.save_config.dtype), + direct_save=direct_save + ) + else: + if self.save_config.save_format == "diffusers": + # saving as a folder path + file_path = file_path.replace('.safetensors', '') + # convert it back to normal object + save_meta = parse_metadata_from_safetensors(save_meta) + + if self.sd.refiner_unet and self.train_config.train_refiner: + # save refiner + refiner_name = self.job.name + '_refiner' + filename = f'{refiner_name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + self.sd.save_refiner( + file_path, + save_meta, + get_torch_dtype(self.save_config.dtype) + ) + if self.train_config.train_unet or self.train_config.train_text_encoder: + self.sd.save( + file_path, + save_meta, + get_torch_dtype(self.save_config.dtype) + ) + + # save learnable params 
as json if we have thim + if self.snr_gos: + json_data = { + 'offset_1': self.snr_gos.offset_1.item(), + 'offset_2': self.snr_gos.offset_2.item(), + 'scale': self.snr_gos.scale.item(), + 'gamma': self.snr_gos.gamma.item(), + } + path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json') + with open(path_to_save, 'w') as f: + json.dump(json_data, f, indent=4) + + # save optimizer + if self.optimizer is not None: + try: + filename = f'optimizer.pt' + file_path = os.path.join(self.save_root, filename) + state_dict = self.optimizer.state_dict() + torch.save(state_dict, file_path) + except Exception as e: + print(e) + print("Could not save optimizer") + + self.print(f"Saved to {file_path}") + self.clean_up_saves() + self.post_save_hook(file_path) + + if self.ema is not None: + self.ema.train() + flush() + + # Called before the model is loaded + def hook_before_model_load(self): + # override in subclass + pass + + def hook_after_model_load(self): + # override in subclass + pass + + def hook_add_extra_train_params(self, params): + # override in subclass + return params + + def hook_before_train_loop(self): + self.logger.start() + + def ensure_params_requires_grad(self, force=False): + if self.train_config.do_paramiter_swapping and not force: + # the optimizer will handle this if we are not forcing + return + for group in self.params: + for param in group['params']: + if isinstance(param, torch.nn.Parameter): # Ensure it's a proper parameter + param.requires_grad_(True) + + def setup_ema(self): + if self.train_config.ema_config.use_ema: + # our params are in groups. We need them as a single iterable + params = [] + for group in self.optimizer.param_groups: + for param in group['params']: + params.append(param) + self.ema = ExponentialMovingAverage( + params, + decay=self.train_config.ema_config.ema_decay, + use_feedback=self.train_config.ema_config.use_feedback, + param_multiplier=self.train_config.ema_config.param_multiplier, + ) + + def before_dataset_load(self): + pass + + def get_params(self): + # you can extend this in subclass to get params + # otherwise params will be gathered through normal means + return None + + def hook_train_loop(self, batch): + # return loss + return 0.0 + + def get_latest_save_path(self, name=None, post=''): + if name == None: + name = self.job.name + # get latest saved step + latest_path = None + if os.path.exists(self.save_root): + # Define patterns for both files and directories + patterns = [ + f"{name}*{post}.safetensors", + f"{name}*{post}.pt", + f"{name}*{post}" + ] + # Search for both files and directories + paths = [] + for pattern in patterns: + paths.extend(glob.glob(os.path.join(self.save_root, pattern))) + + # Filter out non-existent paths and sort by creation time + if paths: + paths = [p for p in paths if os.path.exists(p)] + # remove false positives + if '_LoRA' not in name: + paths = [p for p in paths if '_LoRA' not in p] + if '_refiner' not in name: + paths = [p for p in paths if '_refiner' not in p] + if '_t2i' not in name: + paths = [p for p in paths if '_t2i' not in p] + if '_cn' not in name: + paths = [p for p in paths if '_cn' not in p] + + if len(paths) > 0: + latest_path = max(paths, key=os.path.getctime) + + return latest_path + + def load_training_state_from_metadata(self, path): + meta = None + # if path is folder, then it is diffusers + if os.path.isdir(path): + meta_path = os.path.join(path, 'aitk_meta.yaml') + # load it + if os.path.exists(meta_path): + with open(meta_path, 'r') as f: + meta = yaml.load(f, 
Loader=yaml.FullLoader) + else: + meta = load_metadata_from_safetensors(path) + # if 'training_info' in Orderdict keys + if meta is not None and 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None: + self.step_num = meta['training_info']['step'] + if 'epoch' in meta['training_info']: + self.epoch_num = meta['training_info']['epoch'] + self.start_step = self.step_num + print(f"Found step {self.step_num} in metadata, starting from there") + + def load_weights(self, path): + if self.network is not None: + extra_weights = self.network.load_weights(path) + self.load_training_state_from_metadata(path) + return extra_weights + else: + print("load_weights not implemented for non-network models") + return None + + def apply_snr(self, seperated_loss, timesteps): + if self.train_config.learnable_snr_gos: + # add snr_gamma + seperated_loss = apply_learnable_snr_gos(seperated_loss, timesteps, self.snr_gos) + elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001: + # add snr_gamma + seperated_loss = apply_snr_weight(seperated_loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True) + elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + seperated_loss = apply_snr_weight(seperated_loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + return seperated_loss + + def load_lorm(self): + latest_save_path = self.get_latest_save_path() + if latest_save_path is not None: + # hacky way to reload weights for now + # todo, do this + state_dict = load_file(latest_save_path, device=self.device) + self.sd.unet.load_state_dict(state_dict) + + meta = load_metadata_from_safetensors(latest_save_path) + # if 'training_info' in Orderdict keys + if 'training_info' in meta and 'step' in meta['training_info']: + self.step_num = meta['training_info']['step'] + if 'epoch' in meta['training_info']: + self.epoch_num = meta['training_info']['epoch'] + self.start_step = self.step_num + print(f"Found step {self.step_num} in metadata, starting from there") + + # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): + # self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch) + # sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype) + # schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, ) + # timesteps = timesteps.to(self.device_torch, ) + # + # # step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + # step_indices = [t for t in timesteps] + # + # sigma = sigmas[step_indices].flatten() + # while len(sigma.shape) < n_dim: + # sigma = sigma.unsqueeze(-1) + # return sigma + + def load_additional_training_modules(self, params): + # override in subclass + return params + + def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): + sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device, dtype=dtype) + schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device) + timesteps = timesteps.to(self.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def get_noise(self, latents, batch_size, dtype=torch.float32): + # get noise + noise = self.sd.get_latent_noise( + height=latents.shape[2], + width=latents.shape[3], + batch_size=batch_size, + 
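# Hedged note: noise_offset is the usual "offset noise" trick for letting the model learn
# very dark / very bright images; it is commonly implemented as something like
#   noise = noise + noise_offset * torch.randn(b, c, 1, 1)
# i.e. a small random per-sample, per-channel shift added to the Gaussian noise.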
noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + if self.train_config.random_noise_shift > 0.0: + # get random noise -1 to 1 + noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device, + dtype=noise.dtype) * 2 - 1 + + # multiply by shift amount + noise_shift *= self.train_config.random_noise_shift + + # add to noise + noise += noise_shift + + # standardize the noise + std = noise.std(dim=(2, 3), keepdim=True) + normalizer = 1 / (std + 1e-6) + noise = noise * normalizer + + return noise + + def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): + with torch.no_grad(): + with self.timer('prepare_prompt'): + prompts = batch.get_caption_list() + is_reg_list = batch.get_is_reg_list() + + is_any_reg = any([is_reg for is_reg in is_reg_list]) + + do_double = self.train_config.short_and_long_captions and not is_any_reg + + if self.train_config.short_and_long_captions and do_double: + # dont do this with regs. No point + + # double batch and add short captions to the end + prompts = prompts + batch.get_caption_short_list() + is_reg_list = is_reg_list + is_reg_list + if self.model_config.refiner_name_or_path is not None and self.train_config.train_unet: + prompts = prompts + prompts + is_reg_list = is_reg_list + is_reg_list + + conditioned_prompts = [] + + for prompt, is_reg in zip(prompts, is_reg_list): + + # make sure the embedding is in the prompts + if self.embedding is not None: + prompt = self.embedding.inject_embedding_to_prompt( + prompt, + expand_token=True, + add_if_not_present=not is_reg, + ) + + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + prompt = self.adapter.inject_trigger_into_prompt( + prompt, + expand_token=True, + add_if_not_present=not is_reg, + ) + + # make sure trigger is in the prompts if not a regularization run + if self.trigger_word is not None: + prompt = self.sd.inject_trigger_into_prompt( + prompt, + trigger=self.trigger_word, + add_if_not_present=not is_reg, + ) + + if not is_reg and self.train_config.prompt_saturation_chance > 0.0: + # do random prompt saturation by expanding the prompt to hit at least 77 tokens + if random.random() < self.train_config.prompt_saturation_chance: + est_num_tokens = len(prompt.split(' ')) + if est_num_tokens < 77: + num_repeats = int(77 / est_num_tokens) + 1 + prompt = ', '.join([prompt] * num_repeats) + + + conditioned_prompts.append(prompt) + + with self.timer('prepare_latents'): + dtype = get_torch_dtype(self.train_config.dtype) + imgs = None + is_reg = any(batch.get_is_reg_list()) + if batch.tensor is not None: + imgs = batch.tensor + imgs = imgs.to(self.device_torch, dtype=dtype) + # dont adjust for regs. 
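# Illustrative summary of the branch below: non-reg images may get their contrast reduced
# via img_multiplier, and when standardize_images is set each channel is re-normalized and
# mapped onto fixed target statistics, roughly:
#   x_hat = (x - mean_c(x)) / std_c(x)            # per-channel zero mean / unit std
#   x_out = x_hat * target_std_c + target_mean_c  # hard-coded per-channel constants below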
+ if self.train_config.img_multiplier is not None and not is_reg: + # do it ad contrast + imgs = reduce_contrast(imgs, self.train_config.img_multiplier) + if batch.latents is not None: + latents = batch.latents.to(self.device_torch, dtype=dtype) + batch.latents = latents + else: + # normalize to + if self.train_config.standardize_images: + if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd: + target_mean_list = [0.0002, -0.1034, -0.1879] + target_std_list = [0.5436, 0.5116, 0.5033] + else: + target_mean_list = [-0.0739, -0.1597, -0.2380] + target_std_list = [0.5623, 0.5295, 0.5347] + # Mean: tensor([-0.0739, -0.1597, -0.2380]) + # Standard Deviation: tensor([0.5623, 0.5295, 0.5347]) + imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True) + imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True) + imgs = (imgs - imgs_channel_mean) / imgs_channel_std + target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype) + target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype) + # expand them to match dim + target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3) + target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + imgs = imgs * target_std + target_mean + batch.tensor = imgs + + # show_tensors(imgs, 'imgs') + + latents = self.sd.encode_images(imgs) + batch.latents = latents + + if self.train_config.standardize_latents: + if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd: + target_mean_list = [-0.1075, 0.0231, -0.0135, 0.2164] + target_std_list = [0.8979, 0.7505, 0.9150, 0.7451] + else: + target_mean_list = [0.2949, -0.3188, 0.0807, 0.1929] + target_std_list = [0.8560, 0.9629, 0.7778, 0.6719] + + latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) + latents_channel_std = latents.std(dim=(2, 3), keepdim=True) + latents = (latents - latents_channel_mean) / latents_channel_std + target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype) + target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype) + # expand them to match dim + target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3) + target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + latents = latents * target_std + target_mean + batch.latents = latents + + # show_latents(latents, self.sd.vae, 'latents') + + + if batch.unconditional_tensor is not None and batch.unconditional_latents is None: + unconditional_imgs = batch.unconditional_tensor + unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype) + unconditional_latents = self.sd.encode_images(unconditional_imgs) + batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier + + unaugmented_latents = None + if self.train_config.loss_target == 'differential_noise': + # we determine noise from the differential of the latents + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) + + batch_size = len(batch.file_items) + min_noise_steps = self.train_config.min_denoising_steps + max_noise_steps = self.train_config.max_denoising_steps + if self.model_config.refiner_name_or_path is not None: + # if we are not training the unet, then we are only doing refiner and do not need to double up + if self.train_config.train_unet: + max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) + do_double = True + else: + min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) + do_double = False + + with 
self.timer('prepare_noise'): + num_train_timesteps = self.train_config.num_train_timesteps + + if self.train_config.noise_scheduler in ['custom_lcm']: + # we store this value on our custom one + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.train_timesteps, device=self.device_torch + ) + elif self.train_config.noise_scheduler in ['lcm']: + self.sd.noise_scheduler.set_timesteps( + num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps + ) + elif self.train_config.noise_scheduler == 'flowmatch': + linear_timesteps = any([ + self.train_config.linear_timesteps, + self.train_config.linear_timesteps2, + self.train_config.timestep_type == 'linear', + ]) + self.sd.noise_scheduler.set_train_timesteps( + num_train_timesteps, + device=self.device_torch, + linear=linear_timesteps + ) + else: + self.sd.noise_scheduler.set_timesteps( + num_train_timesteps, device=self.device_torch + ) + + content_or_style = self.train_config.content_or_style + if is_reg: + content_or_style = self.train_config.content_or_style_reg + + # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content': + if content_or_style in ['style', 'content']: + # this is from diffusers training code + # Cubic sampling for favoring later or earlier timesteps + # For more details about why cubic sampling is used for content / structure, + # refer to section 3.4 of https://arxiv.org/abs/2302.08453 + + # for content / structure, it is best to favor earlier timesteps + # for style, it is best to favor later timesteps + + orig_timesteps = torch.rand((batch_size,), device=latents.device) + + if content_or_style == 'content': + timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps + elif content_or_style == 'style': + timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps + + timestep_indices = value_map( + timestep_indices, + 0, + self.train_config.num_train_timesteps - 1, + min_noise_steps, + max_noise_steps - 1 + ) + timestep_indices = timestep_indices.long().clamp( + min_noise_steps + 1, + max_noise_steps - 1 + ) + + elif content_or_style == 'balanced': + if min_noise_steps == max_noise_steps: + timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps + else: + # todo, some schedulers use indices, otheres use timesteps. Not sure what to do here + timestep_indices = torch.randint( + min_noise_steps + 1, + max_noise_steps - 1, + (batch_size,), + device=self.device_torch + ) + timestep_indices = timestep_indices.long() + else: + raise ValueError(f"Unknown content_or_style {content_or_style}") + + # do flow matching + # if self.sd.is_flow_matching: + # u = compute_density_for_timestep_sampling( + # weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"] + # batch_size=batch_size, + # logit_mean=0.0, + # logit_std=1.0, + # mode_scale=1.29, + # ) + # timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long() + # convert the timestep_indices to a timestep + timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices] + timesteps = torch.stack(timesteps, dim=0) + + # get noise + noise = self.get_noise(latents, batch_size, dtype=dtype) + + # add dynamic noise offset. 
Dynamic noise is offsetting the noise to the same channelwise mean as the latents + # this will negate any noise offsets + if self.train_config.dynamic_noise_offset and not is_reg: + latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2 + # subtract channel mean to that we compensate for the mean of the latents on the noise offset per channel + noise = noise + latents_channel_mean + + if self.train_config.loss_target == 'differential_noise': + differential = latents - unaugmented_latents + # add noise to differential + # noise = noise + differential + noise = noise + (differential * 0.5) + # noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max()) + latents = unaugmented_latents + + noise_multiplier = self.train_config.noise_multiplier + + noise = noise * noise_multiplier + + latent_multiplier = self.train_config.latent_multiplier + + # handle adaptive scaling mased on std + if self.train_config.adaptive_scaling_factor: + std = latents.std(dim=(2, 3), keepdim=True) + normalizer = 1 / (std + 1e-6) + latent_multiplier = normalizer + + latents = latents * latent_multiplier + batch.latents = latents + + # normalize latents to a mean of 0 and an std of 1 + # mean_zero_latents = latents - latents.mean() + # latents = mean_zero_latents / mean_zero_latents.std() + + if batch.unconditional_latents is not None: + batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier + + + noisy_latents = self.sd.add_noise(latents, noise, timesteps) + + # determine scaled noise + # todo do we need to scale this or does it always predict full intensity + # noise = noisy_latents - latents + + # https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77 + if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented': + sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) + # add it to the batch + batch.sigmas = sigmas + # todo is this for sdxl? 
find out where this came from originally + # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + + def double_up_tensor(tensor: torch.Tensor): + if tensor is None: + return None + return torch.cat([tensor, tensor], dim=0) + + if do_double: + if self.model_config.refiner_name_or_path: + # apply refiner double up + refiner_timesteps = torch.randint( + max_noise_steps, + self.train_config.max_denoising_steps, + (batch_size,), + device=self.device_torch + ) + refiner_timesteps = refiner_timesteps.long() + # add our new timesteps on to end + timesteps = torch.cat([timesteps, refiner_timesteps], dim=0) + + refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps) + noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0) + + else: + # just double it + noisy_latents = double_up_tensor(noisy_latents) + timesteps = double_up_tensor(timesteps) + + noise = double_up_tensor(noise) + # prompts are already updated above + imgs = double_up_tensor(imgs) + batch.mask_tensor = double_up_tensor(batch.mask_tensor) + batch.control_tensor = double_up_tensor(batch.control_tensor) + + noisy_latent_multiplier = self.train_config.noisy_latent_multiplier + + if noisy_latent_multiplier != 1.0: + noisy_latents = noisy_latents * noisy_latent_multiplier + + # remove grads for these + noisy_latents.requires_grad = False + noisy_latents = noisy_latents.detach() + noise.requires_grad = False + noise = noise.detach() + + return noisy_latents, noise, timesteps, conditioned_prompts, imgs + + def setup_adapter(self): + # t2i adapter + is_t2i = self.adapter_config.type == 't2i' + is_control_net = self.adapter_config.type == 'control_net' + if self.adapter_config.type == 't2i': + suffix = 't2i' + elif self.adapter_config.type == 'control_net': + suffix = 'cn' + elif self.adapter_config.type == 'clip': + suffix = 'clip' + elif self.adapter_config.type == 'reference': + suffix = 'ref' + elif self.adapter_config.type.startswith('ip'): + suffix = 'ip' + else: + suffix = 'adapter' + adapter_name = self.name + if self.network_config is not None: + adapter_name = f"{adapter_name}_{suffix}" + latest_save_path = self.get_latest_save_path(adapter_name) + + dtype = get_torch_dtype(self.train_config.dtype) + if is_t2i: + # if we do not have a last save path and we have a name_or_path, + # load from that + if latest_save_path is None and self.adapter_config.name_or_path is not None: + self.adapter = T2IAdapter.from_pretrained( + self.adapter_config.name_or_path, + torch_dtype=get_torch_dtype(self.train_config.dtype), + varient="fp16", + # use_safetensors=True, + ) + else: + self.adapter = T2IAdapter( + in_channels=self.adapter_config.in_channels, + channels=self.adapter_config.channels, + num_res_blocks=self.adapter_config.num_res_blocks, + downscale_factor=self.adapter_config.downscale_factor, + adapter_type=self.adapter_config.adapter_type, + ) + elif is_control_net: + if self.adapter_config.name_or_path is None: + raise ValueError("ControlNet requires a name_or_path to load from currently") + load_from_path = self.adapter_config.name_or_path + if latest_save_path is not None: + load_from_path = latest_save_path + self.adapter = ControlNetModel.from_pretrained( + load_from_path, + torch_dtype=get_torch_dtype(self.train_config.dtype), + ) + elif self.adapter_config.type == 'clip': + self.adapter = ClipVisionAdapter( + sd=self.sd, + adapter_config=self.adapter_config, + ) + elif self.adapter_config.type == 'reference': + self.adapter = ReferenceAdapter( + sd=self.sd, + 
adapter_config=self.adapter_config, + ) + elif self.adapter_config.type.startswith('ip'): + self.adapter = IPAdapter( + sd=self.sd, + adapter_config=self.adapter_config, + ) + if self.train_config.gradient_checkpointing: + self.adapter.enable_gradient_checkpointing() + else: + self.adapter = CustomAdapter( + sd=self.sd, + adapter_config=self.adapter_config, + ) + self.adapter.to(self.device_torch, dtype=dtype) + if latest_save_path is not None and not is_control_net: + # load adapter from path + print(f"Loading adapter from {latest_save_path}") + if is_t2i: + loaded_state_dict = load_t2i_model( + latest_save_path, + self.device, + dtype=dtype + ) + self.adapter.load_state_dict(loaded_state_dict) + elif self.adapter_config.type.startswith('ip'): + # ip adapter + loaded_state_dict = load_ip_adapter_model( + latest_save_path, + self.device, + dtype=dtype, + direct_load=self.adapter_config.train_only_image_encoder + ) + self.adapter.load_state_dict(loaded_state_dict) + else: + # custom adapter + loaded_state_dict = load_custom_adapter_model( + latest_save_path, + self.device, + dtype=dtype + ) + self.adapter.load_state_dict(loaded_state_dict) + if latest_save_path is not None and self.adapter_config.train: + self.load_training_state_from_metadata(latest_save_path) + # set trainable params + self.sd.adapter = self.adapter + + def run(self): + # torch.autograd.set_detect_anomaly(True) + # run base process run + BaseTrainProcess.run(self) + params = [] + + ### HOOK ### + self.hook_before_model_load() + model_config_to_load = copy.deepcopy(self.model_config) + + if self.is_fine_tuning: + # get the latest checkpoint + # check to see if we have a latest save + latest_save_path = self.get_latest_save_path() + + if latest_save_path is not None: + print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + model_config_to_load.name_or_path = latest_save_path + self.load_training_state_from_metadata(latest_save_path) + + # get the noise scheduler + sampler = get_sampler( + self.train_config.noise_scheduler, + { + "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", + }, + 'sd' if not self.model_config.is_pixart else 'pixart' + ) + + if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None: + previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner') + if previous_refiner_save is not None: + model_config_to_load.refiner_name_or_path = previous_refiner_save + self.load_training_state_from_metadata(previous_refiner_save) + + self.sd = StableDiffusion( + device=self.device, + model_config=model_config_to_load, + dtype=self.train_config.dtype, + custom_pipeline=self.custom_pipeline, + noise_scheduler=sampler, + ) + # run base sd process run + self.sd.load_model() + + dtype = get_torch_dtype(self.train_config.dtype) + + # model is loaded from BaseSDProcess + unet = self.sd.unet + vae = self.sd.vae + tokenizer = self.sd.tokenizer + text_encoder = self.sd.text_encoder + noise_scheduler = self.sd.noise_scheduler + + if self.train_config.xformers: + vae.enable_xformers_memory_efficient_attention() + unet.enable_xformers_memory_efficient_attention() + if isinstance(text_encoder, list): + for te in text_encoder: + # if it has it + if hasattr(te, 'enable_xformers_memory_efficient_attention'): + te.enable_xformers_memory_efficient_attention() + if self.train_config.sdp: + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_flash_sdp(True) + 
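# The torch.backends.cuda.enable_*_sdp flags select which scaled-dot-product attention
# kernels PyTorch is allowed to use (math, FlashAttention, memory-efficient); enabling
# all of them lets F.scaled_dot_product_attention pick the fastest available backend.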
torch.backends.cuda.enable_mem_efficient_sdp(True) + + # # check if we have sage and is flux + # if self.sd.is_flux: + # # try_to_activate_sage_attn() + # try: + # from sageattention import sageattn + # from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0 + # model: FluxTransformer2DModel = self.sd.unet + # # enable sage attention on each block + # for block in model.transformer_blocks: + # processor = FluxSageAttnProcessor2_0() + # block.attn.set_processor(processor) + # for block in model.single_transformer_blocks: + # processor = FluxSageAttnProcessor2_0() + # block.attn.set_processor(processor) + + # except ImportError: + # print("sage attention is not installed. Using SDP instead") + + if self.train_config.gradient_checkpointing: + if self.sd.is_flux: + unet.gradient_checkpointing = True + else: + unet.enable_gradient_checkpointing() + if isinstance(text_encoder, list): + for te in text_encoder: + if hasattr(te, 'enable_gradient_checkpointing'): + te.enable_gradient_checkpointing() + if hasattr(te, "gradient_checkpointing_enable"): + te.gradient_checkpointing_enable() + else: + if hasattr(text_encoder, 'enable_gradient_checkpointing'): + text_encoder.enable_gradient_checkpointing() + if hasattr(text_encoder, "gradient_checkpointing_enable"): + text_encoder.gradient_checkpointing_enable() + + if self.sd.refiner_unet is not None: + self.sd.refiner_unet.to(self.device_torch, dtype=dtype) + self.sd.refiner_unet.requires_grad_(False) + self.sd.refiner_unet.eval() + if self.train_config.xformers: + self.sd.refiner_unet.enable_xformers_memory_efficient_attention() + if self.train_config.gradient_checkpointing: + self.sd.refiner_unet.enable_gradient_checkpointing() + + if isinstance(text_encoder, list): + for te in text_encoder: + te.requires_grad_(False) + te.eval() + else: + text_encoder.requires_grad_(False) + text_encoder.eval() + unet.to(self.device_torch, dtype=dtype) + unet.requires_grad_(False) + unet.eval() + vae = vae.to(torch.device('cpu'), dtype=dtype) + vae.requires_grad_(False) + vae.eval() + if self.train_config.learnable_snr_gos: + self.snr_gos = LearnableSNRGamma( + self.sd.noise_scheduler, device=self.device_torch + ) + # check to see if previous settings exist + path_to_load = os.path.join(self.save_root, 'learnable_snr.json') + if os.path.exists(path_to_load): + with open(path_to_load, 'r') as f: + json_data = json.load(f) + if 'offset' in json_data: + # legacy + self.snr_gos.offset_2.data = torch.tensor(json_data['offset'], device=self.device_torch) + else: + self.snr_gos.offset_1.data = torch.tensor(json_data['offset_1'], device=self.device_torch) + self.snr_gos.offset_2.data = torch.tensor(json_data['offset_2'], device=self.device_torch) + self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch) + self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch) + + self.hook_after_model_load() + flush() + if not self.is_fine_tuning: + if self.network_config is not None: + # TODO should we completely switch to LycorisSpecialNetwork? 
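The surrounding setup follows a freeze-the-base-model pattern: the UNet, VAE, and text encoder(s) are put in `eval()` with gradients disabled so that only the LoRA/adapter/embedding parameters added later receive updates. A minimal sketch of that pattern with placeholder modules (not the toolkit's real classes):

```py
import torch
from torch import nn

backbone = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))  # stands in for unet/vae/text encoder
injected = nn.Linear(16, 16, bias=False)                                   # stands in for a trainable LoRA/adapter module

backbone.requires_grad_(False)  # no gradients flow into the frozen weights
backbone.eval()                 # disable dropout/batchnorm updates on the frozen part

trainable = sum(p.numel() for p in injected.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in backbone.parameters())
print(f"trainable params: {trainable}, frozen params: {frozen}")
```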
+ network_kwargs = self.network_config.network_kwargs + is_lycoris = False + is_lorm = self.network_config.type.lower() == 'lorm' + # default to LoCON if there are any conv layers or if it is named + NetworkClass = LoRASpecialNetwork + if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris': + NetworkClass = LycorisSpecialNetwork + is_lycoris = True + + if is_lorm: + network_kwargs['ignore_if_contains'] = lorm_ignore_if_contains + network_kwargs['parameter_threshold'] = lorm_parameter_threshold + network_kwargs['target_lin_modules'] = LORM_TARGET_REPLACE_MODULE + + # if is_lycoris: + # preset = PRESET['full'] + # NetworkClass.apply_preset(preset) + + self.network = NetworkClass( + text_encoder=text_encoder, + unet=unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=is_lorm, + is_lorm=is_lorm, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + **network_kwargs + ) + + + # todo switch everything to proper mixed precision like this + self.network.force_to(self.device_torch, dtype=torch.float32) + # give network to sd so it can use it + self.sd.network = self.network + self.network._update_torch_multiplier() + + self.network.apply_to( + text_encoder, + unet, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + + # we cannot merge in if quantized + if self.model_config.quantize: + # todo find a way around this + self.network.can_merge_in = False + + if is_lorm: + self.network.is_lorm = True + # make sure it is on the right device + self.sd.unet.to(self.sd.device, dtype=dtype) + original_unet_param_count = count_parameters(self.sd.unet) + self.network.setup_lorm() + new_unet_param_count = original_unet_param_count - self.network.calculate_lorem_parameter_reduction() + + print_lorm_extract_details( + start_num_params=original_unet_param_count, + end_num_params=new_unet_param_count, + num_replaced=len(self.network.get_all_modules()), + ) + + self.network.prepare_grad_etc(text_encoder, unet) + flush() + + # LyCORIS doesnt have default_lr + config = { + 'text_encoder_lr': self.train_config.lr, + 'unet_lr': self.train_config.lr, + } + sig = inspect.signature(self.network.prepare_optimizer_params) + if 'default_lr' in sig.parameters: + config['default_lr'] = self.train_config.lr + if 'learning_rate' in sig.parameters: + config['learning_rate'] = self.train_config.lr + params_net = self.network.prepare_optimizer_params( + **config + ) + + params += params_net + + if self.train_config.gradient_checkpointing: + self.network.enable_gradient_checkpointing() + + lora_name = self.name + # need to adapt name so they are not mixed up + if self.named_lora: + lora_name = f"{lora_name}_LoRA" + + latest_save_path = 
self.get_latest_save_path(lora_name) + extra_weights = None + if latest_save_path is not None: + self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + self.print(f"Loading from {latest_save_path}") + extra_weights = self.load_weights(latest_save_path) + self.network.multiplier = 1.0 + + if self.embed_config is not None: + # we are doing embedding training as well + self.embedding = Embedding( + sd=self.sd, + embed_config=self.embed_config + ) + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) + # load last saved weights + if latest_save_path is not None: + self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) + if self.embedding.step > 1: + self.step_num = self.embedding.step + self.start_step = self.step_num + + # self.step_num = self.embedding.step + # self.start_step = self.step_num + params.append({ + 'params': list(self.embedding.get_trainable_params()), + 'lr': self.train_config.embedding_lr + }) + + flush() + + if self.decorator_config is not None: + self.decorator = Decorator( + num_tokens=self.decorator_config.num_tokens, + token_size=4096 # t5xxl hidden size for flux + ) + latest_save_path = self.get_latest_save_path() + # load last saved weights + if latest_save_path is not None: + state_dict = load_file(latest_save_path) + self.decorator.load_state_dict(state_dict) + self.load_training_state_from_metadata(path) + + params.append({ + 'params': list(self.decorator.parameters()), + 'lr': self.train_config.lr + }) + + # give it to the sd network + self.sd.decorator = self.decorator + self.decorator.to(self.device_torch, dtype=torch.float32) + self.decorator.train() + + flush() + + if self.adapter_config is not None: + self.setup_adapter() + if self.adapter_config.train: + + if isinstance(self.adapter, IPAdapter): + # we have custom LR groups for IPAdapter + adapter_param_groups = self.adapter.get_parameter_groups(self.train_config.adapter_lr) + for group in adapter_param_groups: + params.append(group) + else: + # set trainable params + params.append({ + 'params': list(self.adapter.parameters()), + 'lr': self.train_config.adapter_lr + }) + + if self.train_config.gradient_checkpointing: + self.adapter.enable_gradient_checkpointing() + flush() + + params = self.load_additional_training_modules(params) + + else: # no network, embedding or adapter + # set the device state preset before getting params + self.sd.set_device_state(self.get_params_device_state_preset) + + # params = self.get_params() + if len(params) == 0: + # will only return savable weights and ones with grad + params = self.sd.prepare_optimizer_params( + unet=self.train_config.train_unet, + text_encoder=self.train_config.train_text_encoder, + text_encoder_lr=self.train_config.lr, + unet_lr=self.train_config.lr, + default_lr=self.train_config.lr, + refiner=self.train_config.train_refiner and self.sd.refiner_unet is not None, + refiner_lr=self.train_config.refiner_lr, + ) + # we may be using it for prompt injections + if self.adapter_config is not None and self.adapter is None: + self.setup_adapter() + flush() + ### HOOK ### + params = self.hook_add_extra_train_params(params) + self.params = params + # self.params = [] + + # for param in params: + # if isinstance(param, dict): + # self.params += param['params'] + # else: + # self.params.append(param) + + if self.train_config.start_step is not None: + self.step_num = self.train_config.start_step + self.start_step = self.step_num + + optimizer_type = self.train_config.optimizer.lower() + + # esure params require 
grad + self.ensure_params_requires_grad(force=True) + optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr, + optimizer_params=self.train_config.optimizer_params) + self.optimizer = optimizer + + # set it to do paramiter swapping + if self.train_config.do_paramiter_swapping: + # only works for adafactor, but it should have thrown an error prior to this otherwise + self.optimizer.enable_paramiter_swapping(self.train_config.paramiter_swapping_factor) + + # check if it exists + optimizer_state_filename = f'optimizer.pt' + optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename) + if os.path.exists(optimizer_state_file_path): + # try to load + # previous param groups + # previous_params = copy.deepcopy(optimizer.param_groups) + previous_lrs = [] + for group in optimizer.param_groups: + previous_lrs.append(group['lr']) + + try: + print(f"Loading optimizer state from {optimizer_state_file_path}") + optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + flush() + except Exception as e: + print(f"Failed to load optimizer state from {optimizer_state_file_path}") + print(e) + + # update the optimizer LR from the params + print(f"Updating optimizer LR from params") + if len(previous_lrs) > 0: + for i, group in enumerate(optimizer.param_groups): + group['lr'] = previous_lrs[i] + group['initial_lr'] = previous_lrs[i] + + # Update the learning rates if they changed + # optimizer.param_groups = previous_params + + lr_scheduler_params = self.train_config.lr_scheduler_params + + # make sure it had bare minimum + if 'max_iterations' not in lr_scheduler_params: + lr_scheduler_params['total_iters'] = self.train_config.steps + + lr_scheduler = get_lr_scheduler( + self.train_config.lr_scheduler, + optimizer, + **lr_scheduler_params + ) + self.lr_scheduler = lr_scheduler + + ### HOOk ### + self.before_dataset_load() + # load datasets if passed in the root process + if self.datasets is not None: + self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd) + if self.datasets_reg is not None: + self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, + self.sd) + + flush() + ### HOOK ### + self.hook_before_train_loop() + + if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling: + self.print("Generating first sample from first sample config") + self.sample(0, is_first=True) + + # sample first + if self.train_config.skip_first_sample or self.train_config.disable_sampling: + self.print("Skipping first sample due to config setting") + elif self.step_num <= 1 or self.train_config.force_first_sample: + self.print("Generating baseline samples before training") + self.sample(self.step_num) + + self.progress_bar = ToolkitProgressBar( + total=self.train_config.steps, + desc=self.job.name, + leave=True, + initial=self.step_num, + iterable=range(0, self.train_config.steps), + ) + self.progress_bar.pause() + + if self.data_loader is not None: + dataloader = self.data_loader + dataloader_iterator = iter(dataloader) + else: + dataloader = None + dataloader_iterator = None + + if self.data_loader_reg is not None: + dataloader_reg = self.data_loader_reg + dataloader_iterator_reg = iter(dataloader_reg) + else: + dataloader_reg = None + dataloader_iterator_reg = None + + # zero any gradients + optimizer.zero_grad() + + 
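The optimizer block here resumes `optimizer.pt` when it exists and then re-applies the configured learning rates, so a resumed run does not silently keep the LR stored inside the checkpoint. A condensed sketch of that resume pattern with generic names (the file name and optimizer below are placeholders):

```py
import torch
from torch import nn

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# ... after some training, persist the optimizer state (written to the working directory here) ...
torch.save(optimizer.state_dict(), "optimizer.pt")

# ... on resume: load the saved state, then restore the LRs from the current config ...
configured_lrs = [group["lr"] for group in optimizer.param_groups]
optimizer.load_state_dict(torch.load("optimizer.pt", weights_only=True))
for group, lr in zip(optimizer.param_groups, configured_lrs):
    group["lr"] = lr
    group["initial_lr"] = lr
```

Restoring `initial_lr` alongside `lr` keeps schedulers that read it from the param groups consistent after the reload.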
self.lr_scheduler.step(self.step_num) + + self.sd.set_device_state(self.train_device_state_preset) + flush() + # self.step_num = 0 + + # print(f"Compiling Model") + # torch.compile(self.sd.unet, dynamic=True) + + # make sure all params require grad + self.ensure_params_requires_grad(force=True) + + + ################################################################### + # TRAIN LOOP + ################################################################### + + + start_step_num = self.step_num + did_first_flush = False + for step in range(start_step_num, self.train_config.steps): + if self.train_config.do_paramiter_swapping: + self.optimizer.swap_paramiters() + self.timer.start('train_loop') + if self.train_config.do_random_cfg: + self.train_config.do_cfg = True + self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale) + self.step_num = step + # default to true so various things can turn it off + self.is_grad_accumulation_step = True + if self.train_config.free_u: + self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2) + self.progress_bar.unpause() + with torch.no_grad(): + # if is even step and we have a reg dataset, use that + # todo improve this logic to send one of each through if we can buckets and batch size might be an issue + is_reg_step = False + is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0 + is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0 + if self.train_config.disable_sampling: + is_sample_step = False + + batch_list = [] + + for b in range(self.train_config.gradient_accumulation): + # keep track to alternate on an accumulation step for reg + batch_step = step + # don't do a reg step on sample or save steps as we dont want to normalize on those + if batch_step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step: + try: + with self.timer('get_batch:reg'): + batch = next(dataloader_iterator_reg) + except StopIteration: + with self.timer('reset_batch:reg'): + # hit the end of an epoch, reset + self.progress_bar.pause() + dataloader_iterator_reg = iter(dataloader_reg) + trigger_dataloader_setup_epoch(dataloader_reg) + + with self.timer('get_batch:reg'): + batch = next(dataloader_iterator_reg) + self.progress_bar.unpause() + is_reg_step = True + elif dataloader is not None: + try: + with self.timer('get_batch'): + batch = next(dataloader_iterator) + except StopIteration: + with self.timer('reset_batch'): + # hit the end of an epoch, reset + self.progress_bar.pause() + dataloader_iterator = iter(dataloader) + trigger_dataloader_setup_epoch(dataloader) + self.epoch_num += 1 + if self.train_config.gradient_accumulation_steps == -1: + # if we are accumulating for an entire epoch, trigger a step + self.is_grad_accumulation_step = False + self.grad_accumulation_step = 0 + with self.timer('get_batch'): + batch = next(dataloader_iterator) + self.progress_bar.unpause() + else: + batch = None + batch_list.append(batch) + batch_step += 1 + + # setup accumulation + if self.train_config.gradient_accumulation_steps == -1: + # epoch is handling the accumulation, dont touch it + pass + else: + # determine if we are accumulating or not + # since optimizer step happens in the loop, we trigger it a step early + # since we cannot reprocess it before them + optimizer_step_at = self.train_config.gradient_accumulation_steps + is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at + 
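The bookkeeping here decides when an accumulation window ends and the optimizer is allowed to step. The same logic in isolation, with a placeholder model and data, looks like this:

```py
import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
accumulation_steps = 4  # analogous to gradient_accumulation_steps

for micro_step in range(16):
    x = torch.randn(2, 4)
    loss = model(x).pow(2).mean() / accumulation_steps  # scale so the summed grads match one large batch
    loss.backward()

    is_optimizer_step = (micro_step + 1) % accumulation_steps == 0
    if is_optimizer_step:
        optimizer.step()
        optimizer.zero_grad()
```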
self.is_grad_accumulation_step = not is_optimizer_step + if is_optimizer_step: + self.grad_accumulation_step = 0 + + # flush() + ### HOOK ### + + loss_dict = self.hook_train_loop(batch_list) + self.timer.stop('train_loop') + if not did_first_flush: + flush() + did_first_flush = True + # flush() + # setup the networks to gradient checkpointing and everything works + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + + with torch.no_grad(): + # torch.cuda.empty_cache() + # if optimizer has get_lrs method, then use it + if hasattr(optimizer, 'get_avg_learning_rate'): + learning_rate = optimizer.get_avg_learning_rate() + elif hasattr(optimizer, 'get_learning_rates'): + learning_rate = optimizer.get_learning_rates()[0] + elif self.train_config.optimizer.lower().startswith('dadaptation') or \ + self.train_config.optimizer.lower().startswith('prodigy'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + + prog_bar_string = f"lr: {learning_rate:.1e}" + for key, value in loss_dict.items(): + prog_bar_string += f" {key}: {value:.3e}" + + self.progress_bar.set_postfix_str(prog_bar_string) + + # if the batch is a DataLoaderBatchDTO, then we need to clean it up + if isinstance(batch, DataLoaderBatchDTO): + with self.timer('batch_cleanup'): + batch.cleanup() + + # don't do on first step + if self.step_num != self.start_step: + if is_sample_step: + self.progress_bar.pause() + flush() + # print above the progress bar + if self.train_config.free_u: + self.sd.pipeline.disable_freeu() + self.sample(self.step_num) + if self.train_config.unload_text_encoder: + # make sure the text encoder is unloaded + self.sd.text_encoder_to('cpu') + flush() + + self.ensure_params_requires_grad() + self.progress_bar.unpause() + + if is_save_step: + # print above the progress bar + self.progress_bar.pause() + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) + self.ensure_params_requires_grad() + self.progress_bar.unpause() + + if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: + self.progress_bar.pause() + with self.timer('log_to_tensorboard'): + # log to tensorboard + if self.writer is not None: + for key, value in loss_dict.items(): + self.writer.add_scalar(f"{key}", value, self.step_num) + self.writer.add_scalar(f"lr", learning_rate, self.step_num) + self.progress_bar.unpause() + + # log to logger + self.logger.log({ + 'learning_rate': learning_rate, + }) + for key, value in loss_dict.items(): + self.logger.log({ + f'loss/{key}': value, + }) + elif self.logging_config.log_every is None: + # log every step + self.logger.log({ + 'learning_rate': learning_rate, + }) + for key, value in loss_dict.items(): + self.logger.log({ + f'loss/{key}': value, + }) + + + if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0: + self.progress_bar.pause() + # print the timers and clear them + self.timer.print() + self.timer.reset() + self.progress_bar.unpause() + + # commit log + self.logger.commit(step=self.step_num) + + # sets progress bar to match out step + self.progress_bar.update(step - self.progress_bar.n) + + ############################# + # End of step + ############################# + + # update various steps + self.step_num = step + 1 + self.grad_accumulation_step += 1 + + + ################################################################### + ## END TRAIN LOOP + 
################################################################### + + self.progress_bar.close() + if self.train_config.free_u: + self.sd.pipeline.disable_freeu() + if not self.train_config.disable_sampling: + self.sample(self.step_num) + self.logger.commit(step=self.step_num) + print("") + self.save() + self.logger.finish() + + if self.save_config.push_to_hub: + if("HF_TOKEN" not in os.environ): + interpreter_login(new_session=False, write_permission=True) + self.push_to_hub( + repo_id=self.save_config.hf_repo_id, + private=self.save_config.hf_private + ) + del ( + self.sd, + unet, + noise_scheduler, + optimizer, + self.network, + tokenizer, + text_encoder, + ) + + flush() + + def push_to_hub( + self, + repo_id: str, + private: bool = False, + ): + readme_content = self._generate_readme(repo_id) + readme_path = os.path.join(self.save_root, "README.md") + with open(readme_path, "w", encoding="utf-8") as f: + f.write(readme_content) + + api = HfApi() + + api.create_repo( + repo_id, + private=private, + exist_ok=True + ) + + api.upload_folder( + repo_id=repo_id, + folder_path=self.save_root, + ignore_patterns=["*.yaml", "*.pt"], + repo_type="model", + ) + + + def _generate_readme(self, repo_id: str) -> str: + """Generates the content of the README.md file.""" + + # Gather model info + base_model = self.model_config.name_or_path + instance_prompt = self.trigger_word if hasattr(self, "trigger_word") else None + if base_model == "black-forest-labs/FLUX.1-schnell": + license = "apache-2.0" + elif base_model == "black-forest-labs/FLUX.1-dev": + license = "other" + license_name = "flux-1-dev-non-commercial-license" + license_link = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md" + else: + license = "creativeml-openrail-m" + tags = [ + "text-to-image", + ] + if self.model_config.is_xl: + tags.append("stable-diffusion-xl") + if self.model_config.is_flux: + tags.append("flux") + if self.model_config.is_v3: + tags.append("sd3") + if self.network_config: + tags.extend( + [ + "lora", + "diffusers", + "template:sd-lora", + "ai-toolkit", + ] + ) + + # Generate the widget section + widgets = [] + sample_image_paths = [] + samples_dir = os.path.join(self.save_root, "samples") + if os.path.isdir(samples_dir): + for filename in os.listdir(samples_dir): + #The filenames are structured as 1724085406830__00000500_0.jpg + #So here we capture the 2nd part (steps) and 3rd (index the matches the prompt) + match = re.search(r"__(\d+)_(\d+)\.jpg$", filename) + if match: + steps, index = int(match.group(1)), int(match.group(2)) + #Here we only care about uploading the latest samples, the match with the # of steps + if steps == self.train_config.steps: + sample_image_paths.append((index, f"samples/{filename}")) + + # Sort by numeric index + sample_image_paths.sort(key=lambda x: x[0]) + + # Create widgets matching prompt with the index + for i, prompt in enumerate(self.sample_config.prompts): + if i < len(sample_image_paths): + # Associate prompts with sample image paths based on the extracted index + _, image_path = sample_image_paths[i] + widgets.append( + { + "text": prompt, + "output": { + "url": image_path + }, + } + ) + dtype = "torch.bfloat16" if self.model_config.is_flux else "torch.float16" + # Construct the README content + readme_content = f"""--- +tags: +{yaml.dump(tags, indent=4).strip()} +{"widget:" if os.path.isdir(samples_dir) else ""} +{yaml.dump(widgets, indent=4).strip() if widgets else ""} +base_model: {base_model} +{"instance_prompt: " + instance_prompt if instance_prompt 
else ""} +license: {license} +{'license_name: ' + license_name if license == "other" else ""} +{'license_link: ' + license_link if license == "other" else ""} +--- + +# {self.job.name} +Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) + + +## Trigger words + +{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."} + +## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc. + +Weights for this model are available in Safetensors format. + +[Download](/{repo_id}/tree/main) them in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch + +pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='{self.job.name}.safetensors') +image = pipeline('{instance_prompt if not widgets else self.sample_config.prompts[0]}').images[0] +image.save("my_image.png") +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +""" + return readme_content diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..d1885de23930ab372f26efa2b2281c9a9332f4fe --- /dev/null +++ b/jobs/process/BaseTrainProcess.py @@ -0,0 +1,79 @@ +import random +from datetime import datetime +import os +from collections import OrderedDict +from typing import TYPE_CHECKING, Union + +import torch +import yaml + +from jobs.process.BaseProcess import BaseProcess + +if TYPE_CHECKING: + from jobs import TrainJob, BaseJob, ExtensionJob + from torch.utils.tensorboard import SummaryWriter + from tqdm import tqdm + + +class BaseTrainProcess(BaseProcess): + + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.process_id: int + self.config: OrderedDict + self.writer: 'SummaryWriter' + self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob'] + self.progress_bar: 'tqdm' = None + + self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None) + # if training seed is set, use it + if self.training_seed is not None: + torch.manual_seed(self.training_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(self.training_seed) + random.seed(self.training_seed) + + self.progress_bar = None + self.writer = None + self.training_folder = self.get_conf('training_folder', + self.job.training_folder if hasattr(self.job, 'training_folder') else None) + self.save_root = os.path.join(self.training_folder, self.name) + self.step = 0 + self.first_step = 0 + self.log_dir = self.get_conf('log_dir', self.job.log_dir if hasattr(self.job, 'log_dir') else None) + self.setup_tensorboard() + self.save_training_config() + + def run(self): + super().run() + # implement in child class + # be sure to call super().run() first + pass + + # def print(self, message, **kwargs): + def print(self, *args): + if self.progress_bar is not None: + self.progress_bar.write(' '.join(map(str, args))) + self.progress_bar.update() + else: + print(*args) + + def setup_tensorboard(self): + if self.log_dir: + from torch.utils.tensorboard import SummaryWriter + now = 
datetime.now() + time_str = now.strftime('%Y%m%d-%H%M%S') + summary_name = f"{self.name}_{time_str}" + summary_dir = os.path.join(self.log_dir, summary_name) + self.writer = SummaryWriter(summary_dir) + + def save_training_config(self): + os.makedirs(self.save_root, exist_ok=True) + save_dif = os.path.join(self.save_root, f'config.yaml') + with open(save_dif, 'w') as f: + yaml.dump(self.job.raw_config, f) diff --git a/jobs/process/ExtractLoconProcess.py b/jobs/process/ExtractLoconProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..b5dac5edd7bcc5fb959fb4a3717bfa975d1264cc --- /dev/null +++ b/jobs/process/ExtractLoconProcess.py @@ -0,0 +1,68 @@ +from collections import OrderedDict +from toolkit.lycoris_utils import extract_diff +from .BaseExtractProcess import BaseExtractProcess + +mode_dict = { + 'fixed': { + 'linear': 64, + 'conv': 32, + 'type': int + }, + 'threshold': { + 'linear': 0, + 'conv': 0, + 'type': float + }, + 'ratio': { + 'linear': 0.5, + 'conv': 0.5, + 'type': float + }, + 'quantile': { + 'linear': 0.5, + 'conv': 0.5, + 'type': float + } +} + + +class ExtractLoconProcess(BaseExtractProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.mode = self.get_conf('mode', 'fixed') + self.use_sparse_bias = self.get_conf('use_sparse_bias', False) + self.sparsity = self.get_conf('sparsity', 0.98) + self.disable_cp = self.get_conf('disable_cp', False) + + # set modes + if self.mode not in list(mode_dict.keys()): + raise ValueError(f"Unknown mode: {self.mode}") + self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) + self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type']) + + def run(self): + super().run() + print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}") + + state_dict, extract_diff_meta = extract_diff( + self.job.model_base, + self.job.model_extract, + self.mode, + self.linear_param, + self.conv_param, + self.job.device, + self.use_sparse_bias, + self.sparsity, + not self.disable_cp, + extract_unet=self.extract_unet, + extract_text_encoder=self.extract_text_encoder + ) + + self.add_meta(extract_diff_meta) + self.save(state_dict) + + def get_output_path(self, prefix=None, suffix=None): + if suffix is None: + suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}" + return super().get_output_path(prefix, suffix) + diff --git a/jobs/process/ExtractLoraProcess.py b/jobs/process/ExtractLoraProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..76f0cc942b0c6d76139223851965e643dfb31376 --- /dev/null +++ b/jobs/process/ExtractLoraProcess.py @@ -0,0 +1,73 @@ +from collections import OrderedDict +from toolkit.lycoris_utils import extract_diff +from .BaseExtractProcess import BaseExtractProcess + + +mode_dict = { + 'fixed': { + 'linear': 4, + 'conv': 0, + 'type': int + }, + 'threshold': { + 'linear': 0, + 'conv': 0, + 'type': float + }, + 'ratio': { + 'linear': 0.5, + 'conv': 0, + 'type': float + }, + 'quantile': { + 'linear': 0.5, + 'conv': 0, + 'type': float + } +} + +CLAMP_QUANTILE = 0.99 +MIN_DIFF = 1e-6 + + +class ExtractLoraProcess(BaseExtractProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.mode = self.get_conf('mode', 'fixed') + + # set modes + if self.mode not in list(mode_dict.keys()): + raise ValueError(f"Unknown mode: 
{self.mode}") + self.linear = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) + self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) + self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type']) + self.use_sparse_bias = self.get_conf('use_sparse_bias', False) + self.sparsity = self.get_conf('sparsity', 0.98) + + def run(self): + super().run() + print(f"Running process: {self.mode}, dim: {self.dim}") + + state_dict, extract_diff_meta = extract_diff( + self.job.model_base, + self.job.model_extract, + self.mode, + self.linear_param, + self.conv_param, + self.job.device, + self.use_sparse_bias, + self.sparsity, + small_conv=False, + linear_only=self.conv_param > 0.0000000001, + extract_unet=self.extract_unet, + extract_text_encoder=self.extract_text_encoder + ) + + self.add_meta(extract_diff_meta) + self.save(state_dict) + + def get_output_path(self, prefix=None, suffix=None): + if suffix is None: + suffix = f"_{self.dim}" + return super().get_output_path(prefix, suffix) diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..e0cb32d8e0d1cdc6bdc1f723bf73495fae14c809 --- /dev/null +++ b/jobs/process/GenerateProcess.py @@ -0,0 +1,146 @@ +import gc +import os +from collections import OrderedDict +from typing import ForwardRef, List, Optional, Union + +import torch +from safetensors.torch import save_file, load_file + +from jobs.process.BaseProcess import BaseProcess +from toolkit.config_modules import ModelConfig, GenerateImageConfig +from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \ + add_base_model_info_to_meta +from toolkit.stable_diffusion_model import StableDiffusion +from toolkit.train_tools import get_torch_dtype +import random + + +class GenerateConfig: + + def __init__(self, **kwargs): + self.prompts: List[str] + self.sampler = kwargs.get('sampler', 'ddpm') + self.width = kwargs.get('width', 512) + self.height = kwargs.get('height', 512) + self.size_list: Union[List[int], None] = kwargs.get('size_list', None) + self.neg = kwargs.get('neg', '') + self.seed = kwargs.get('seed', -1) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.prompt_2 = kwargs.get('prompt_2', None) + self.neg_2 = kwargs.get('neg_2', None) + self.prompts = kwargs.get('prompts', None) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.compile = kwargs.get('compile', False) + self.ext = kwargs.get('ext', 'png') + self.prompt_file = kwargs.get('prompt_file', False) + self.num_repeats = kwargs.get('num_repeats', 1) + self.prompts_in_file = self.prompts + if self.prompts is None: + raise ValueError("Prompts must be set") + if isinstance(self.prompts, str): + if os.path.exists(self.prompts): + with open(self.prompts, 'r', encoding='utf-8') as f: + self.prompts_in_file = f.read().splitlines() + self.prompts_in_file = [p.strip() for p in self.prompts_in_file if len(p.strip()) > 0] + else: + raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts") + + self.random_prompts = kwargs.get('random_prompts', False) + self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1) + self.max_images = kwargs.get('max_images', 10000) + + if self.random_prompts: + self.prompts = [] + for i in 
range(self.max_images): + num_prompts = random.randint(1, self.max_random_per_prompt) + prompt_list = [random.choice(self.prompts_in_file) for _ in range(num_prompts)] + self.prompts.append(", ".join(prompt_list)) + else: + self.prompts = self.prompts_in_file + + if kwargs.get('shuffle', False): + # shuffle the prompts + random.shuffle(self.prompts) + + +class GenerateProcess(BaseProcess): + process_id: int + config: OrderedDict + progress_bar: ForwardRef('tqdm') = None + sd: StableDiffusion + + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.device = self.get_conf('device', self.job.device) + self.generate_config = GenerateConfig(**self.get_conf('generate', required=True)) + self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16')) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.model_config.dtype, + ) + + print(f"Using device {self.device}") + + def clean_prompt(self, prompt: str): + # remove any non alpha numeric characters or ,'" from prompt + return ''.join(e for e in prompt if e.isalnum() or e in ", '\"") + + def run(self): + with torch.no_grad(): + super().run() + print("Loading model...") + self.sd.load_model() + self.sd.pipeline.to(self.device, self.torch_dtype) + + print("Compiling model...") + # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True) + if self.generate_config.compile: + self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead") + + print(f"Generating {len(self.generate_config.prompts)} images") + # build prompt image configs + prompt_image_configs = [] + for _ in range(self.generate_config.num_repeats): + for prompt in self.generate_config.prompts: + width = self.generate_config.width + height = self.generate_config.height + # prompt = self.clean_prompt(prompt) + + if self.generate_config.size_list is not None: + # randomly select a size + width, height = random.choice(self.generate_config.size_list) + + prompt_image_configs.append(GenerateImageConfig( + prompt=prompt, + prompt_2=self.generate_config.prompt_2, + width=width, + height=height, + num_inference_steps=self.generate_config.sample_steps, + guidance_scale=self.generate_config.guidance_scale, + negative_prompt=self.generate_config.neg, + negative_prompt_2=self.generate_config.neg_2, + seed=self.generate_config.seed, + guidance_rescale=self.generate_config.guidance_rescale, + output_ext=self.generate_config.ext, + output_folder=self.output_folder, + add_prompt_file=self.generate_config.prompt_file + )) + # generate images + self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler) + + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/jobs/process/MergeLoconProcess.py b/jobs/process/MergeLoconProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..00c70cd2abdbc894f7b00c6cbf51a3dcfcc95531 --- /dev/null +++ b/jobs/process/MergeLoconProcess.py @@ -0,0 +1,20 @@ +from collections import OrderedDict +from toolkit.lycoris_utils import extract_diff +from .BaseExtractProcess import BaseExtractProcess + + +class MergeLoconProcess(BaseExtractProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + + 
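For reference, the random-prompt composition that `GenerateConfig` performs above reduces to the following standalone sketch (the prompt list and limits are made-up values):

```py
import random

prompts_in_file = ["a red fox", "studio lighting", "watercolor", "wide angle"]  # stands in for the prompt file lines
max_random_per_prompt = 2
max_images = 5

prompts = []
for _ in range(max_images):
    num_prompts = random.randint(1, max_random_per_prompt)
    prompts.append(", ".join(random.choice(prompts_in_file) for _ in range(num_prompts)))

print(prompts)  # e.g. ['watercolor', 'a red fox, wide angle', ...]
```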
def run(self): + super().run() + new_state_dict = {} + raise NotImplementedError("This is not implemented yet") + + + def get_output_path(self, prefix=None, suffix=None): + if suffix is None: + suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}" + return super().get_output_path(prefix, suffix) + diff --git a/jobs/process/ModRescaleLoraProcess.py b/jobs/process/ModRescaleLoraProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb7436098f95ed774c7f1febc1b8bd7c0791981 --- /dev/null +++ b/jobs/process/ModRescaleLoraProcess.py @@ -0,0 +1,104 @@ +import gc +import os +from collections import OrderedDict +from typing import ForwardRef + +import torch +from safetensors.torch import save_file, load_file + +from jobs.process.BaseProcess import BaseProcess +from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \ + add_base_model_info_to_meta +from toolkit.train_tools import get_torch_dtype + + +class ModRescaleLoraProcess(BaseProcess): + process_id: int + config: OrderedDict + progress_bar: ForwardRef('tqdm') = None + + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.process_id: int + self.config: OrderedDict + self.progress_bar: ForwardRef('tqdm') = None + self.input_path = self.get_conf('input_path', required=True) + self.output_path = self.get_conf('output_path', required=True) + self.replace_meta = self.get_conf('replace_meta', default=False) + self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype) + self.current_weight = self.get_conf('current_weight', required=True, as_type=float) + self.target_weight = self.get_conf('target_weight', required=True, as_type=float) + self.scale_target = self.get_conf('scale_target', default='up_down') # alpha or up_down + self.is_xl = self.get_conf('is_xl', default=False, as_type=bool) + self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool) + + self.progress_bar = None + + def run(self): + super().run() + source_state_dict = load_file(self.input_path) + source_meta = load_metadata_from_safetensors(self.input_path) + + if self.replace_meta: + self.meta.update( + add_base_model_info_to_meta( + self.meta, + is_xl=self.is_xl, + is_v2=self.is_v2, + ) + ) + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + else: + save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False) + + # save + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + + new_state_dict = OrderedDict() + + for key in list(source_state_dict.keys()): + v = source_state_dict[key] + v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32')) + + # all loras have an alpha, up weight and down weight + # - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha", + # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight", + # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight", + # we can rescale by adjusting the alpha or the up weights, or the up and down weights + # I assume doing both up and down would be best all around, but I'm not sure + # some locons also have mid weights, we will leave those alone for now, will work without them + + # when adjusting alpha, it is used to calculate the multiplier in a lora module + # - scale = alpha / lora_dim + # - output = layer_out + lora_up_out * multiplier * scale + total_module_scale = torch.tensor(self.current_weight / self.target_weight) \ + .to("cpu", 
dtype=get_torch_dtype('fp32')) + num_modules_layers = 2 # up and down + up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \ + .to("cpu", dtype=get_torch_dtype('fp32')) + # only update alpha + if self.scale_target == 'alpha' and key.endswith('.alpha'): + v = v * total_module_scale + if self.scale_target == 'up_down' and key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'): + # would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN + v = v * up_down_scale + v = v.detach().clone().to("cpu").to(self.save_dtype) + new_state_dict[key] = v + + save_meta = add_model_hash_to_meta(new_state_dict, save_meta) + save_file(new_state_dict, self.output_path, save_meta) + + # cleanup incase there are other jobs + del new_state_dict + del source_state_dict + del source_meta + + torch.cuda.empty_cache() + gc.collect() + + print(f"Saved to {self.output_path}") diff --git a/jobs/process/TrainESRGANProcess.py b/jobs/process/TrainESRGANProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff3a69d89260396232f6085d161db3afe26b668 --- /dev/null +++ b/jobs/process/TrainESRGANProcess.py @@ -0,0 +1,657 @@ +import copy +import glob +import os +import time +from collections import OrderedDict +from typing import List, Optional + +from PIL import Image +from PIL.ImageOps import exif_transpose + +from toolkit.basic import flush +from toolkit.models.RRDB import RRDBNet as ESRGAN, esrgan_safetensors_keys +from safetensors.torch import save_file, load_file +from torch.utils.data import DataLoader, ConcatDataset +import torch +from torch import nn +from torchvision.transforms import transforms + +from jobs.process import BaseTrainProcess +from toolkit.data_loader import AugmentedImageDataset +from toolkit.esrgan_utils import convert_state_dict_to_basicsr, convert_basicsr_state_dict_to_save_format +from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss +from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer +from toolkit.style import get_style_model_and_losses +from toolkit.train_tools import get_torch_dtype +from diffusers import AutoencoderKL +from tqdm import tqdm +import time +import numpy as np +from .models.vgg19_critic import Critic + +IMAGE_TRANSFORMS = transforms.Compose( + [ + transforms.ToTensor(), + # transforms.Normalize([0.5], [0.5]), + ] +) + + +class TrainESRGANProcess(BaseTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.data_loader = None + self.model: ESRGAN = None + self.device = self.get_conf('device', self.job.device) + self.pretrained_path = self.get_conf('pretrained_path', 'None') + self.datasets_objects = self.get_conf('datasets', required=True) + self.batch_size = self.get_conf('batch_size', 1, as_type=int) + self.resolution = self.get_conf('resolution', 256, as_type=int) + self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float) + self.sample_every = self.get_conf('sample_every', None) + self.optimizer_type = self.get_conf('optimizer', 'adam') + self.epochs = self.get_conf('epochs', None, as_type=int) + self.max_steps = self.get_conf('max_steps', None, as_type=int) + self.save_every = self.get_conf('save_every', None) + self.upscale_sample = self.get_conf('upscale_sample', 4) + self.dtype = self.get_conf('dtype', 'float32') + self.sample_sources = self.get_conf('sample_sources', None) + self.log_every = 
self.get_conf('log_every', 100, as_type=int) + self.style_weight = self.get_conf('style_weight', 0, as_type=float) + self.content_weight = self.get_conf('content_weight', 0, as_type=float) + self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) + self.zoom = self.get_conf('zoom', 4, as_type=int) + self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) + self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) + self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float) + self.optimizer_params = self.get_conf('optimizer_params', {}) + self.augmentations = self.get_conf('augmentations', {}) + self.torch_dtype = get_torch_dtype(self.dtype) + if self.torch_dtype == torch.bfloat16: + self.esrgan_dtype = torch.float32 + else: + self.esrgan_dtype = torch.float32 + + self.vgg_19 = None + self.style_weight_scalers = [] + self.content_weight_scalers = [] + + # throw error if zoom if not divisible by 2 + if self.zoom % 2 != 0: + raise ValueError('zoom must be divisible by 2') + + self.step_num = 0 + self.epoch_num = 0 + + self.use_critic = self.get_conf('use_critic', False, as_type=bool) + self.critic = None + + if self.use_critic: + self.critic = Critic( + device=self.device, + dtype=self.dtype, + process=self, + **self.get_conf('critic', {}) # pass any other params + ) + + if self.sample_every is not None and self.sample_sources is None: + raise ValueError('sample_every is specified but sample_sources is not') + + if self.epochs is None and self.max_steps is None: + raise ValueError('epochs or max_steps must be specified') + + self.data_loaders = [] + # check datasets + assert isinstance(self.datasets_objects, list) + for dataset in self.datasets_objects: + if 'path' not in dataset: + raise ValueError('dataset must have a path') + # check if is dir + if not os.path.isdir(dataset['path']): + raise ValueError(f"dataset path does is not a directory: {dataset['path']}") + + # make training folder + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + self._pattern_loss = None + + # build augmentation transforms + aug_transforms = [] + + def update_training_metadata(self): + self.add_meta(OrderedDict({"training_info": self.get_training_info()})) + + def get_training_info(self): + info = OrderedDict({ + 'step': self.step_num, + 'epoch': self.epoch_num, + }) + return info + + def load_datasets(self): + if self.data_loader is None: + print(f"Loading datasets") + datasets = [] + for dataset in self.datasets_objects: + print(f" - Dataset: {dataset['path']}") + ds = copy.copy(dataset) + ds['resolution'] = self.resolution + + if 'augmentations' not in ds: + ds['augmentations'] = self.augmentations + + # add the resize down augmentation + ds['augmentations'] = [{ + 'method': 'Resize', + 'params': { + 'width': int(self.resolution // self.zoom), + 'height': int(self.resolution // self.zoom), + # downscale interpolation, string will be evaluated + 'interpolation': 'cv2.INTER_AREA' + } + }] + ds['augmentations'] + + image_dataset = AugmentedImageDataset(ds) + datasets.append(image_dataset) + + concatenated_dataset = ConcatDataset(datasets) + self.data_loader = DataLoader( + concatenated_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=6 + ) + + def setup_vgg19(self): + if self.vgg_19 is None: + self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses( + single_target=True, + device=self.device, + output_layer_name='pool_4', + dtype=self.torch_dtype + ) + 
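The dataset setup above builds low-res/high-res training pairs by prepending a Resize augmentation that shrinks each crop by the `zoom` factor with an area filter. A hypothetical standalone equivalent of that pairing (OpenCV assumed available, random data standing in for a real crop):

```py
import numpy as np
import cv2  # the augmentation config above references 'cv2.INTER_AREA' for the same downscale

resolution, zoom = 256, 4
hr = (np.random.rand(resolution, resolution, 3) * 255).astype(np.uint8)  # high-res target crop
lr = cv2.resize(hr, (resolution // zoom, resolution // zoom), interpolation=cv2.INTER_AREA)  # low-res model input

print(hr.shape, lr.shape)  # (256, 256, 3) (64, 64, 3)
```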
self.vgg_19.to(self.device, dtype=self.torch_dtype) + self.vgg_19.requires_grad_(False) + + # we run random noise through first to get layer scalers to normalize the loss per layer + # bs of 2 because we run pred and target through stacked + noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype) + self.vgg_19(noise) + for style_loss in self.style_losses: + # get a scaler to normalize to 1 + scaler = 1 / torch.mean(style_loss.loss).item() + self.style_weight_scalers.append(scaler) + for content_loss in self.content_losses: + # get a scaler to normalize to 1 + scaler = 1 / torch.mean(content_loss.loss).item() + # if is nan, set to 1 + if scaler != scaler: + scaler = 1 + print(f"Warning: content loss scaler is nan, setting to 1") + self.content_weight_scalers.append(scaler) + + self.print(f"Style weight scalers: {self.style_weight_scalers}") + self.print(f"Content weight scalers: {self.content_weight_scalers}") + + def get_style_loss(self): + if self.style_weight > 0: + # scale all losses with loss scalers + loss = torch.sum( + torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)])) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_content_loss(self): + if self.content_weight > 0: + # scale all losses with loss scalers + loss = torch.sum(torch.stack( + [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)])) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_mse_loss(self, pred, target): + if self.mse_weight > 0: + loss_fn = nn.MSELoss() + loss = loss_fn(pred, target) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_tv_loss(self, pred, target): + if self.tv_weight > 0: + get_tv_loss = ComparativeTotalVariation() + loss = get_tv_loss(pred, target) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_pattern_loss(self, pred, target): + if self._pattern_loss is None: + self._pattern_loss = PatternLoss( + pattern_size=self.zoom, + dtype=self.torch_dtype + ).to(self.device, dtype=self.torch_dtype) + self._pattern_loss = self._pattern_loss.to(self.device, dtype=self.torch_dtype) + loss = torch.mean(self._pattern_loss(pred, target)) + return loss + + def save(self, step=None): + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + + self.update_training_metadata() + # filename = f'{self.job.name}{step_num}.safetensors' + filename = f'{self.job.name}{step_num}.pth' + # prepare meta + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + + # state_dict = self.model.state_dict() + + # state has the original state dict keys so we can save what we started from + save_state_dict = self.model.state_dict() + + for key in list(save_state_dict.keys()): + v = save_state_dict[key] + v = v.detach().clone().to("cpu").to(torch.float32) + save_state_dict[key] = v + + # most things wont use safetensors, save as torch + # save_file(save_state_dict, os.path.join(self.save_root, filename), save_meta) + torch.save(save_state_dict, os.path.join(self.save_root, filename)) + + self.print(f"Saved to {os.path.join(self.save_root, filename)}") + + if self.use_critic: + self.critic.save(step) + + def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None): + sample_folder = os.path.join(self.save_root, 'samples') + if 
not os.path.exists(sample_folder): + os.makedirs(sample_folder, exist_ok=True) + batch_sample_folder = os.path.join(self.save_root, 'samples_batch') + + batch_targets = None + batch_inputs = None + if batch is not None and not os.path.exists(batch_sample_folder): + os.makedirs(batch_sample_folder, exist_ok=True) + + self.model.eval() + + def process_and_save(img, target_img, save_path): + img = img.to(self.device, dtype=self.esrgan_dtype) + output = self.model(img) + # output = (output / 2 + 0.5).clamp(0, 1) + output = output.clamp(0, 1) + img = img.clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy() + img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy() + + # convert to pillow image + output = Image.fromarray((output * 255).astype(np.uint8)) + img = Image.fromarray((img * 255).astype(np.uint8)) + + if isinstance(target_img, torch.Tensor): + # convert to pil + target_img = target_img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy() + target_img = Image.fromarray((target_img * 255).astype(np.uint8)) + + # upscale to size * self.upscale_sample while maintaining pixels + output = output.resize( + (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample), + resample=Image.NEAREST + ) + img = img.resize( + (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample), + resample=Image.NEAREST + ) + + width, height = output.size + + # stack input image and decoded image + target_image = target_img.resize((width, height)) + output = output.resize((width, height)) + img = img.resize((width, height)) + + output_img = Image.new('RGB', (width * 3, height)) + + output_img.paste(img, (0, 0)) + output_img.paste(output, (width, 0)) + output_img.paste(target_image, (width * 2, 0)) + + output_img.save(save_path) + + with torch.no_grad(): + for i, img_url in enumerate(self.sample_sources): + img = exif_transpose(Image.open(img_url)) + img = img.convert('RGB') + # crop if not square + if img.width != img.height: + min_dim = min(img.width, img.height) + img = img.crop((0, 0, min_dim, min_dim)) + # resize + img = img.resize((self.resolution * self.zoom, self.resolution * self.zoom), resample=Image.BICUBIC) + + target_image = img + # downscale the image input + img = img.resize((self.resolution, self.resolution), resample=Image.BICUBIC) + + # downscale the image input + + img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype) + img = img + + step_num = '' + if step is not None: + # zero-pad 9 digits + step_num = f"_{str(step).zfill(9)}" + seconds_since_epoch = int(time.time()) + # zero-pad 2 digits + i_str = str(i).zfill(2) + filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg" + process_and_save(img, target_image, os.path.join(sample_folder, filename)) + + if batch is not None: + batch_targets = batch[0].detach() + batch_inputs = batch[1].detach() + batch_targets = torch.chunk(batch_targets, batch_targets.shape[0], dim=0) + batch_inputs = torch.chunk(batch_inputs, batch_inputs.shape[0], dim=0) + + for i in range(len(batch_inputs)): + if step is not None: + # zero-pad 9 digits + step_num = f"_{str(step).zfill(9)}" + seconds_since_epoch = int(time.time()) + # zero-pad 2 digits + i_str = str(i).zfill(2) + filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg" + process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename)) + + self.model.train() + + def 
load_model(self): + state_dict = None + path_to_load = self.pretrained_path + # see if we have a checkpoint in out output to resume from + self.print(f"Looking for latest checkpoint in {self.save_root}") + files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors")) + files += glob.glob(os.path.join(self.save_root, f"{self.job.name}*.pth")) + if files and len(files) > 0: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + # todo update step and epoch count + elif self.pretrained_path is None: + self.print(f" - No checkpoint found, starting from scratch") + else: + self.print(f" - No checkpoint found, loading pretrained model") + self.print(f" - path: {path_to_load}") + + if path_to_load is not None: + self.print(f" - Loading pretrained checkpoint: {path_to_load}") + # if ends with pth then assume pytorch checkpoint + if path_to_load.endswith('.pth') or path_to_load.endswith('.pt'): + state_dict = torch.load(path_to_load, map_location=self.device) + elif path_to_load.endswith('.safetensors'): + state_dict_raw = load_file(path_to_load) + # make ordered dict as most things need it + state_dict = OrderedDict() + for key in esrgan_safetensors_keys: + state_dict[key] = state_dict_raw[key] + else: + raise Exception(f"Unknown file extension for checkpoint: {path_to_load}") + + # todo determine architecture from checkpoint + self.model = ESRGAN( + state_dict + ).to(self.device, dtype=self.esrgan_dtype) + + # set the model to training mode + self.model.train() + self.model.requires_grad_(True) + + def run(self): + super().run() + self.load_datasets() + steps_per_step = (self.critic.num_critic_per_gen + 1) + + max_step_epochs = self.max_steps // (len(self.data_loader) // steps_per_step) + num_epochs = self.epochs + if num_epochs is None or num_epochs > max_step_epochs: + num_epochs = max_step_epochs + + max_epoch_steps = len(self.data_loader) * num_epochs * steps_per_step + num_steps = self.max_steps + if num_steps is None or num_steps > max_epoch_steps: + num_steps = max_epoch_steps + self.max_steps = num_steps + self.epochs = num_epochs + start_step = self.step_num + self.first_step = start_step + + self.print(f"Training ESRGAN model:") + self.print(f" - Training folder: {self.training_folder}") + self.print(f" - Batch size: {self.batch_size}") + self.print(f" - Learning rate: {self.learning_rate}") + self.print(f" - Epochs: {num_epochs}") + self.print(f" - Max steps: {self.max_steps}") + + # load model + self.load_model() + + params = self.model.parameters() + + if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: + self.setup_vgg19() + self.vgg_19.requires_grad_(False) + self.vgg_19.eval() + if self.use_critic: + self.critic.setup() + + optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, + optimizer_params=self.optimizer_params) + + # setup scheduler + # todo allow other schedulers + scheduler = torch.optim.lr_scheduler.ConstantLR( + optimizer, + total_iters=num_steps, + factor=1, + verbose=False + ) + + # setup tqdm progress bar + self.progress_bar = tqdm( + total=num_steps, + desc='Training ESRGAN', + leave=True + ) + + blank_losses = OrderedDict({ + "total": [], + "style": [], + "content": [], + "mse": [], + "kl": [], + "tv": [], + "ptn": [], + "crD": [], + "crG": [], + }) + epoch_losses = copy.deepcopy(blank_losses) + log_losses = copy.deepcopy(blank_losses) + print("Generating baseline samples") + self.sample(step=0) + # range start at self.epoch_num go 
to self.epochs + critic_losses = [] + for epoch in range(self.epoch_num, self.epochs, 1): + if self.step_num >= self.max_steps: + break + flush() + for targets, inputs in self.data_loader: + if self.step_num >= self.max_steps: + break + with torch.no_grad(): + is_critic_only_step = False + if self.use_critic and 1 / (self.critic.num_critic_per_gen + 1) < np.random.uniform(): + is_critic_only_step = True + + targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach() + inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach() + + optimizer.zero_grad() + # dont do grads here for critic step + do_grad = not is_critic_only_step + with torch.set_grad_enabled(do_grad): + pred = self.model(inputs) + + pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1) + targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1) + if torch.isnan(pred).any(): + raise ValueError('pred has nan values') + if torch.isnan(targets).any(): + raise ValueError('targets has nan values') + + # Run through VGG19 + if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: + stacked = torch.cat([pred, targets], dim=0) + # stacked = (stacked / 2 + 0.5).clamp(0, 1) + stacked = stacked.clamp(0, 1) + self.vgg_19(stacked) + # make sure we dont have nans + if torch.isnan(self.vgg19_pool_4.tensor).any(): + raise ValueError('vgg19_pool_4 has nan values') + + if is_critic_only_step: + critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach()) + critic_losses.append(critic_d_loss) + # don't do generator step + continue + else: + # doing a regular step + if len(critic_losses) == 0: + critic_d_loss = 0 + else: + critic_d_loss = sum(critic_losses) / len(critic_losses) + + style_loss = self.get_style_loss() * self.style_weight + content_loss = self.get_content_loss() * self.content_weight + + mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight + tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight + pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight + if self.use_critic: + critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight + else: + critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + # make sure non nan + if torch.isnan(loss): + raise ValueError('loss is nan') + + # Backward pass and optimization + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + optimizer.step() + scheduler.step() + + # update progress bar + loss_value = loss.item() + # get exponent like 3.54e-4 + loss_string = f"loss: {loss_value:.2e}" + if self.content_weight > 0: + loss_string += f" cnt: {content_loss.item():.2e}" + if self.style_weight > 0: + loss_string += f" sty: {style_loss.item():.2e}" + if self.mse_weight > 0: + loss_string += f" mse: {mse_loss.item():.2e}" + if self.tv_weight > 0: + loss_string += f" tv: {tv_loss.item():.2e}" + if self.pattern_weight > 0: + loss_string += f" ptn: {pattern_loss.item():.2e}" + if self.use_critic and self.critic_weight > 0: + loss_string += f" crG: {critic_gen_loss.item():.2e}" + if self.use_critic: + loss_string += f" crD: {critic_d_loss:.2e}" + + if self.optimizer_type.startswith('dadaptation') or self.optimizer_type.startswith('prodigy'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + + lr_critic_string = '' + if 
self.use_critic: + lr_critic = self.critic.get_lr() + lr_critic_string = f" lrC: {lr_critic:.1e}" + + self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}") + self.progress_bar.set_description(f"E: {epoch}") + self.progress_bar.update(1) + + epoch_losses["total"].append(loss_value) + epoch_losses["style"].append(style_loss.item()) + epoch_losses["content"].append(content_loss.item()) + epoch_losses["mse"].append(mse_loss.item()) + epoch_losses["tv"].append(tv_loss.item()) + epoch_losses["ptn"].append(pattern_loss.item()) + epoch_losses["crG"].append(critic_gen_loss.item()) + epoch_losses["crD"].append(critic_d_loss) + + log_losses["total"].append(loss_value) + log_losses["style"].append(style_loss.item()) + log_losses["content"].append(content_loss.item()) + log_losses["mse"].append(mse_loss.item()) + log_losses["tv"].append(tv_loss.item()) + log_losses["ptn"].append(pattern_loss.item()) + log_losses["crG"].append(critic_gen_loss.item()) + log_losses["crD"].append(critic_d_loss) + + # don't sample/save/log on the first step + if self.step_num != start_step: + if self.sample_every and self.step_num % self.sample_every == 0: + # print above the progress bar + self.print(f"Sampling at step {self.step_num}") + self.sample(self.step_num, batch=[targets, inputs]) + + if self.save_every and self.step_num % self.save_every == 0: + # print above the progress bar + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) + + if self.log_every and self.step_num % self.log_every == 0: + # log to tensorboard + if self.writer is not None: + # get avg loss since the last log + for key in log_losses: + log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6) + # if log_losses[key] > 0: + self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num) + # reset log losses + log_losses = copy.deepcopy(blank_losses) + + self.step_num += 1 + # end epoch + if self.writer is not None: + eps = 1e-6 + # get avg loss for the epoch + for key in epoch_losses: + epoch_losses[key] = sum(epoch_losses[key]) / (len(epoch_losses[key]) + eps) + if epoch_losses[key] > 0: + self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch) + # reset epoch losses + epoch_losses = copy.deepcopy(blank_losses) + + self.save() diff --git a/jobs/process/TrainFineTuneProcess.py b/jobs/process/TrainFineTuneProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..a13a7cf640ad2a695d2f330a8cb4636985593376 --- /dev/null +++ b/jobs/process/TrainFineTuneProcess.py @@ -0,0 +1,13 @@ +from collections import OrderedDict +from jobs import TrainJob +from jobs.process import BaseTrainProcess + + +class TrainFineTuneProcess(BaseTrainProcess): + def __init__(self, process_id: int, job: TrainJob, config: OrderedDict): + super().__init__(process_id, job, config) + + def run(self): + # implement in child class + # be sure to call super().run() first + pass diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..cc2dc3398a5c29edcaa386e48117644a395db677 --- /dev/null +++ b/jobs/process/TrainSDRescaleProcess.py @@ -0,0 +1,277 @@ +import glob +import os +from collections import OrderedDict +import random +from typing import Optional, List + +from safetensors.torch import save_file, load_file +from tqdm import tqdm + +from toolkit.layers import ReductionKernel +from toolkit.stable_diffusion_model import PromptEmbeds +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +import gc +from
toolkit import train_tools + +import torch +from leco import train_util, model_util +from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class RescaleConfig: + def __init__( + self, + **kwargs + ): + self.from_resolution = kwargs.get('from_resolution', 512) + self.scale = kwargs.get('scale', 0.5) + self.latent_tensor_dir = kwargs.get('latent_tensor_dir', None) + self.num_latent_tensors = kwargs.get('num_latent_tensors', 1000) + self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale)) + self.prompt_dropout = kwargs.get('prompt_dropout', 0.1) + + +class PromptEmbedsCache: + prompts: dict[str, PromptEmbeds] = {} + + def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class TrainSDRescaleProcess(BaseSDTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + # pass our custom pipeline to super so it sets it up + super().__init__(process_id, job, config) + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True)) + self.reduce_size_fn = ReductionKernel( + in_channels=4, + kernel_size=int(self.rescale_config.from_resolution // self.rescale_config.to_resolution), + dtype=get_torch_dtype(self.train_config.dtype), + device=self.device_torch, + ) + + self.latent_paths: List[str] = [] + self.empty_embedding: PromptEmbeds = None + + def before_model_load(self): + pass + + def get_latent_tensors(self): + dtype = get_torch_dtype(self.train_config.dtype) + + num_to_generate = 0 + # check if dir exists + if not os.path.exists(self.rescale_config.latent_tensor_dir): + os.makedirs(self.rescale_config.latent_tensor_dir) + num_to_generate = self.rescale_config.num_latent_tensors + else: + # find existing + current_tensor_list = glob.glob(os.path.join(self.rescale_config.latent_tensor_dir, "*.safetensors")) + num_to_generate = self.rescale_config.num_latent_tensors - len(current_tensor_list) + self.latent_paths = current_tensor_list + + if num_to_generate > 0: + print(f"Generating {num_to_generate}/{self.rescale_config.num_latent_tensors} latent tensors") + + # unload other model + self.sd.unet.to('cpu') + + # load aux network + self.sd_parent = StableDiffusion( + self.device_torch, + model_config=self.model_config, + dtype=self.train_config.dtype, + ) + self.sd_parent.load_model() + self.sd_parent.unet.to(self.device_torch, dtype=dtype) + # we dont need text encoder for this + + del self.sd_parent.text_encoder + del self.sd_parent.tokenizer + + self.sd_parent.unet.eval() + self.sd_parent.unet.requires_grad_(False) + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + text_embeddings = train_tools.concat_prompt_embeddings( + self.empty_embedding, # unconditional (negative prompt) + self.empty_embedding, # conditional (positive prompt) + self.train_config.batch_size, + ) + torch.set_default_device(self.device_torch) + + for i in tqdm(range(num_to_generate)): + dtype = get_torch_dtype(self.train_config.dtype) + # get a random seed + seed = torch.randint(0, 2 ** 32, (1,)).item() + # zero pad seed string to 
max length + seed_string = str(seed).zfill(10) + # set seed + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # # ger a random number of steps + timesteps_to = self.train_config.max_denoising_steps + + # set the scheduler to the number of steps + self.sd.noise_scheduler.set_timesteps( + timesteps_to, device=self.device_torch + ) + + noise = self.sd.get_latent_noise( + pixel_height=self.rescale_config.from_resolution, + pixel_width=self.rescale_config.from_resolution, + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + # get latents + latents = noise * self.sd.noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) + + # get random guidance scale from 1.0 to 10.0 (CFG) + guidance_scale = torch.rand(1).item() * 9.0 + 1.0 + + # do a timestep of 1 + timestep = 1 + + noise_pred_target = self.sd_parent.predict_noise( + latents, + text_embeddings=text_embeddings, + timestep=timestep, + guidance_scale=guidance_scale + ) + + # build state dict + state_dict = OrderedDict() + state_dict['noise_pred_target'] = noise_pred_target.to('cpu', dtype=torch.float16) + state_dict['latents'] = latents.to('cpu', dtype=torch.float16) + state_dict['guidance_scale'] = torch.tensor(guidance_scale).to('cpu', dtype=torch.float16) + state_dict['timestep'] = torch.tensor(timestep).to('cpu', dtype=torch.float16) + state_dict['timesteps_to'] = torch.tensor(timesteps_to).to('cpu', dtype=torch.float16) + state_dict['seed'] = torch.tensor(seed).to('cpu', dtype=torch.float32) # must be float 32 to prevent overflow + + file_name = f"{seed_string}_{i}.safetensors" + file_path = os.path.join(self.rescale_config.latent_tensor_dir, file_name) + save_file(state_dict, file_path) + self.latent_paths.append(file_path) + + print("Removing parent model") + # delete parent + del self.sd_parent + flush() + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + self.sd.unet.to(self.device_torch, dtype=dtype) + + def hook_before_train_loop(self): + # encode our empty prompt + self.empty_embedding = self.sd.encode_prompt("") + self.empty_embedding = self.empty_embedding.to(self.device_torch, + dtype=get_torch_dtype(self.train_config.dtype)) + + # Move train model encoder to cpu + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to('cpu') + encoder.eval() + encoder.requires_grad_(False) + else: + self.sd.text_encoder.to('cpu') + self.sd.text_encoder.eval() + self.sd.text_encoder.requires_grad_(False) + + # self.sd.unet.to('cpu') + flush() + + self.get_latent_tensors() + + flush() + # end hook_before_train_loop + + def hook_train_loop(self, batch): + dtype = get_torch_dtype(self.train_config.dtype) + + loss_function = torch.nn.MSELoss() + + # train it + # Begin gradient accumulation + self.sd.unet.train() + self.sd.unet.requires_grad_(True) + self.sd.unet.to(self.device_torch, dtype=dtype) + + with torch.no_grad(): + self.optimizer.zero_grad() + + # pick random latent tensor + latent_path = random.choice(self.latent_paths) + latent_tensor = load_file(latent_path) + + noise_pred_target = (latent_tensor['noise_pred_target']).to(self.device_torch, dtype=dtype) + latents = (latent_tensor['latents']).to(self.device_torch, dtype=dtype) + guidance_scale = (latent_tensor['guidance_scale']).item() + timestep = int((latent_tensor['timestep']).item()) + timesteps_to = 
int((latent_tensor['timesteps_to']).item()) + # seed = int((latent_tensor['seed']).item()) + + text_embeddings = train_tools.concat_prompt_embeddings( + self.empty_embedding, # unconditional (negative prompt) + self.empty_embedding, # conditional (positive prompt) + self.train_config.batch_size, + ) + self.sd.noise_scheduler.set_timesteps( + timesteps_to, device=self.device_torch + ) + + denoised_target = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample + + # get the reduced latents + # reduced_pred = self.reduce_size_fn(noise_pred_target.detach()) + denoised_target = self.reduce_size_fn(denoised_target.detach()) + reduced_latents = self.reduce_size_fn(latents.detach()) + + denoised_target.requires_grad = False + self.optimizer.zero_grad() + noise_pred_train = self.sd.predict_noise( + reduced_latents, + text_embeddings=text_embeddings, + timestep=timestep, + guidance_scale=guidance_scale + ) + denoised_pred = self.sd.noise_scheduler.step(noise_pred_train, timestep, reduced_latents).prev_sample + loss = loss_function(denoised_pred, denoised_target) + loss_float = loss.item() + loss.backward() + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + flush() + + loss_dict = OrderedDict( + {'loss': loss_float}, + ) + + return loss_dict + # end hook_train_loop diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..88b9d104e973e00481734182ef55b77767c4be88 --- /dev/null +++ b/jobs/process/TrainSliderProcess.py @@ -0,0 +1,694 @@ +import copy +import os +import random +from collections import OrderedDict +from typing import Union + +from PIL import Image +from diffusers import T2IAdapter +from torchvision.transforms import transforms +from tqdm import tqdm + +from toolkit.basic import value_map +from toolkit.config_modules import SliderConfig +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.sd_device_states_presets import get_train_sd_device_state_preset +from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_learnable_snr_gos +import gc +from toolkit import train_tools +from toolkit.prompt_utils import \ + EncodedPromptPair, ACTION_TYPES_SLIDER, \ + EncodedAnchor, concat_prompt_pairs, \ + concat_anchors, PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \ + split_prompt_pairs + +import torch +from .BaseSDTrainProcess import BaseSDTrainProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +adapter_transforms = transforms.Compose([ + transforms.ToTensor(), +]) + + +class TrainSliderProcess(BaseSDTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.prompt_txt_list = None + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.slider_config = SliderConfig(**self.get_conf('slider', {})) + self.prompt_cache = PromptEmbedsCache() + self.prompt_pairs: list[EncodedPromptPair] = [] + self.anchor_pairs: list[EncodedAnchor] = [] + # keep track of prompt chunk size + self.prompt_chunk_size = 1 + + # check if we have more targets than steps + # this can happen because of permutation son shuffling + if len(self.slider_config.targets) > self.train_config.steps: + # trim targets + self.slider_config.targets = self.slider_config.targets[:self.train_config.steps] + + # get 
presets + self.eval_slider_device_state = get_train_sd_device_state_preset( + self.device_torch, + train_unet=False, + train_text_encoder=False, + cached_latents=self.is_latents_cached, + train_lora=False, + train_adapter=False, + train_embedding=False, + ) + + self.train_slider_device_state = get_train_sd_device_state_preset( + self.device_torch, + train_unet=self.train_config.train_unet, + train_text_encoder=False, + cached_latents=self.is_latents_cached, + train_lora=True, + train_adapter=False, + train_embedding=False, + ) + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + + # read line by line from file + if self.slider_config.prompt_file: + self.print(f"Loading prompt file from {self.slider_config.prompt_file}") + with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f: + self.prompt_txt_list = f.readlines() + # clean empty lines + self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] + + self.print(f"Found {len(self.prompt_txt_list)} prompts.") + + if not self.slider_config.prompt_tensors: + print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.") + # shuffle + random.shuffle(self.prompt_txt_list) + # trim to max steps + self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps] + # trim list to our max steps + + cache = PromptEmbedsCache() + print(f"Building prompt cache") + + # get encoded latents for our prompts + with torch.no_grad(): + # list of neutrals. Can come from file or be empty + neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""] + + # build the prompts to cache + prompts_to_cache = [] + for neutral in neutral_list: + for target in self.slider_config.targets: + prompt_list = [ + f"{target.target_class}", # target_class + f"{target.target_class} {neutral}", # target_class with neutral + f"{target.positive}", # positive_target + f"{target.positive} {neutral}", # positive_target with neutral + f"{target.negative}", # negative_target + f"{target.negative} {neutral}", # negative_target with neutral + f"{neutral}", # neutral + f"{target.positive} {target.negative}", # both targets + f"{target.negative} {target.positive}", # both targets reverse + ] + prompts_to_cache += prompt_list + + # remove duplicates + prompts_to_cache = list(dict.fromkeys(prompts_to_cache)) + + # trim to max steps if max steps is lower than prompt count + # todo, this can break if we have more targets than steps, should be fixed, by reducing permuations, but could stil happen with low steps + # prompts_to_cache = prompts_to_cache[:self.train_config.steps] + + # encode them + cache = encode_prompts_to_cache( + prompt_list=prompts_to_cache, + sd=self.sd, + cache=cache, + prompt_tensor_file=self.slider_config.prompt_tensors + ) + + prompt_pairs = [] + prompt_batches = [] + for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False): + for target in self.slider_config.targets: + prompt_pair_batch = build_prompt_pair_batch_from_cache( + cache=cache, + target=target, + neutral=neutral, + + ) + if self.slider_config.batch_full_slide: + # concat the prompt pairs + # this allows us to run the entire 4 part process in one shot (for slider) + self.prompt_chunk_size = 4 + concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu') + prompt_pairs += [concat_prompt_pair_batch] + else: + self.prompt_chunk_size = 1 + # do them one at a time (probably not necessary after new optimizations) + prompt_pairs += [x.to('cpu') for 
x in prompt_pair_batch] + + # setup anchors + anchor_pairs = [] + for anchor in self.slider_config.anchors: + # build the cache + for prompt in [ + anchor.prompt, + anchor.neg_prompt # empty neutral + ]: + if cache[prompt] is None: + cache[prompt] = self.sd.encode_prompt(prompt) + + anchor_batch = [] + # we get the prompt pair multipliers from the first prompt pair + # since they are all the same. We need to match their network polarity + prompt_pair_multipliers = prompt_pairs[0].multiplier_list + for prompt_multiplier in prompt_pair_multipliers: + # match the network multiplier polarity + anchor_scalar = 1.0 if prompt_multiplier > 0 else -1.0 + anchor_batch += [ + EncodedAnchor( + prompt=cache[anchor.prompt], + neg_prompt=cache[anchor.neg_prompt], + multiplier=anchor.multiplier * anchor_scalar + ) + ] + + anchor_pairs += [ + concat_anchors(anchor_batch).to('cpu') + ] + if len(anchor_pairs) > 0: + self.anchor_pairs = anchor_pairs + + # move to cpu to save vram + # We don't need text encoder anymore, but keep it on cpu for sampling + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") + self.prompt_cache = cache + self.prompt_pairs = prompt_pairs + # self.anchor_pairs = anchor_pairs + flush() + if self.data_loader is not None: + # we will have images, prep the vae + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + # end hook_before_train_loop + + def before_dataset_load(self): + if self.slider_config.use_adapter == 'depth': + print(f"Loading T2I Adapter for depth") + # called before LoRA network is loaded but after model is loaded + # attach the adapter here so it is there before we load the network + adapter_path = 'TencentARC/t2iadapter_depth_sd15v2' + if self.model_config.is_xl: + adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0' + + print(f"Loading T2I Adapter from {adapter_path}") + + # don't name this adapter since we are not training it + self.t2i_adapter = T2IAdapter.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), variant="fp16" + ).to(self.device_torch) + self.t2i_adapter.eval() + self.t2i_adapter.requires_grad_(False) + flush() + + @torch.no_grad() + def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']): + + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + adapter_folder_path = self.slider_config.adapter_img_dir + adapter_images = [] + # loop through images + for file_item in batch.file_items: + img_path = file_item.path + file_name_no_ext = os.path.basename(img_path).split('.')[0] + # find the image + for ext in img_ext_list: + if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)): + adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext)) + break + width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height + adapter_tensors = [] + # load images with torch transforms + for idx, adapter_image in enumerate(adapter_images): + # we need to center-crop the largest dimension of the image to match the batch shape after scaling + # to the smallest dimension + img: Image.Image = Image.open(adapter_image) + if img.width > img.height: + # scale down so height is the same as batch + new_height = height + new_width = int(img.width * (height / img.height)) + else: + new_width = width + new_height = int(img.height * (width / img.width)) + + img = img.resize((new_width, new_height)) + crop_fn = transforms.CenterCrop((height, width)) + #
crop the center to match batch + img = crop_fn(img) + img = adapter_transforms(img) + adapter_tensors.append(img) + + # stack them + adapter_tensors = torch.stack(adapter_tensors).to( + self.device_torch, dtype=get_torch_dtype(self.train_config.dtype) + ) + return adapter_tensors + + def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]): + # set to eval mode + self.sd.set_device_state(self.eval_slider_device_state) + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + + # get a random pair + prompt_pair: EncodedPromptPair = self.prompt_pairs[ + torch.randint(0, len(self.prompt_pairs), (1,)).item() + ] + # move to device and dtype + prompt_pair.to(self.device_torch, dtype=dtype) + + # get a random resolution + height, width = self.slider_config.resolutions[ + torch.randint(0, len(self.slider_config.resolutions), (1,)).item() + ] + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + loss_function = torch.nn.MSELoss() + + pred_kwargs = {} + + def get_noise_pred(neg, pos, gs, cts, dn): + down_kwargs = copy.deepcopy(pred_kwargs) + if 'down_block_additional_residuals' in down_kwargs: + dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0] + if dbr_batch_size != dn.shape[0]: + amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size) + down_kwargs['down_block_additional_residuals'] = [ + torch.cat([sample.clone()] * amount_to_add) for sample in + down_kwargs['down_block_additional_residuals'] + ] + return self.sd.predict_noise( + latents=dn, + text_embeddings=train_tools.concat_prompt_embeddings( + neg, # negative prompt + pos, # positive prompt + self.train_config.batch_size, + ), + timestep=cts, + guidance_scale=gs, + **down_kwargs + ) + + with torch.no_grad(): + adapter_images = None + self.sd.unet.eval() + + # for a complete slider, the batch size is 4 to begin with now + true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size + from_batch = False + if batch is not None: + # traing from a batch of images, not generating ourselves + from_batch = True + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.slider_config.adapter_img_dir is not None: + adapter_images = self.get_adapter_images(batch) + adapter_strength_min = 0.9 + adapter_strength_max = 1.0 + + def rand_strength(sample): + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) + + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + adapter_strength_max + ) + return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale + + down_block_additional_residuals = self.t2i_adapter(adapter_images) + down_block_additional_residuals = [ + rand_strength(sample) for sample in down_block_additional_residuals + ] + pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals + + denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0) + current_timestep = timesteps + else: + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + # ger a random number of steps + timesteps_to = torch.randint( + 1, self.train_config.max_denoising_steps - 1, (1,) + ).item() + + # get noise + noise = 
self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=true_batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + # get latents + latents = noise * self.sd.noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) + + assert not self.network.is_active + self.sd.unet.eval() + # pass the multiplier list to the network + # double up since we are doing cfg + self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list + denoised_latents = self.sd.diffuse_some_steps( + latents, # pass simple noise latents + train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # unconditional + prompt_pair.target_class, # target + self.train_config.batch_size, + ), + start_timesteps=0, + total_timesteps=timesteps_to, + guidance_scale=3, + ) + + + noise_scheduler.set_timesteps(1000) + + current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps) + current_timestep = noise_scheduler.timesteps[current_timestep_index] + + # split the latents into out prompt pair chunks + denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0) + denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks] + + # flush() # 4.2GB to 3GB on 512x512 + mask_multiplier = torch.ones((denoised_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + has_mask = False + if batch and batch.mask_tensor is not None: + with self.timer('get_mask_multiplier'): + # upsampling no supported for bfloat16 + mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach() + # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height) + mask_multiplier = torch.nn.functional.interpolate( + mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3]) + ) + # expand to match latents + mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) + mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + has_mask = True + + if has_mask: + unmasked_target = get_noise_pred( + prompt_pair.positive_target, # negative prompt + prompt_pair.target_class, # positive prompt + 1, + current_timestep, + denoised_latents + ) + unmasked_target = unmasked_target.detach() + unmasked_target.requires_grad = False + else: + unmasked_target = None + + # 4.20 GB RAM for 512x512 + positive_latents = get_noise_pred( + prompt_pair.positive_target, # negative prompt + prompt_pair.negative_target, # positive prompt + 1, + current_timestep, + denoised_latents + ) + positive_latents = positive_latents.detach() + positive_latents.requires_grad = False + + neutral_latents = get_noise_pred( + prompt_pair.positive_target, # negative prompt + prompt_pair.empty_prompt, # positive prompt (normally neutral + 1, + current_timestep, + denoised_latents + ) + neutral_latents = neutral_latents.detach() + neutral_latents.requires_grad = False + + unconditional_latents = get_noise_pred( + prompt_pair.positive_target, # negative prompt + prompt_pair.positive_target, # positive prompt + 1, + current_timestep, + denoised_latents + ) + unconditional_latents = unconditional_latents.detach() + unconditional_latents.requires_grad = False + + denoised_latents = denoised_latents.detach() + + self.sd.set_device_state(self.train_slider_device_state) + self.sd.unet.train() + # start accumulating gradients + self.optimizer.zero_grad(set_to_none=True) + 
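+ # note: gradients from the anchor pass and the slider pass below accumulate into the same + # .grad buffers; optimizer.step() is only called once after both backward passes have run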
+ anchor_loss_float = None + if len(self.anchor_pairs) > 0: + with torch.no_grad(): + # get a random anchor pair + anchor: EncodedAnchor = self.anchor_pairs[ + torch.randint(0, len(self.anchor_pairs), (1,)).item() + ] + anchor.to(self.device_torch, dtype=dtype) + + # first we get the target prediction without network active + anchor_target_noise = get_noise_pred( + anchor.neg_prompt, anchor.prompt, 1, current_timestep, denoised_latents + # ).to("cpu", dtype=torch.float32) + ).requires_grad_(False) + + # to save vram, we will run these through separately while tracking grads + # otherwise it consumes a ton of vram and this isn't our speed bottleneck + anchor_chunks = split_anchors(anchor, self.prompt_chunk_size) + anchor_target_noise_chunks = torch.chunk(anchor_target_noise, self.prompt_chunk_size, dim=0) + assert len(anchor_chunks) == len(denoised_latent_chunks) + + # 4.32 GB RAM for 512x512 + with self.network: + assert self.network.is_active + anchor_float_losses = [] + for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip( + anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks + ): + self.network.multiplier = anchor_chunk.multiplier_list + anchor_chunk.multiplier_list + + anchor_pred_noise = get_noise_pred( + anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk + ) + # 9.42 GB RAM for 512x512 -> 4.20 GB RAM for 512x512 with new grad_checkpointing + anchor_loss = loss_function( + anchor_target_noise_chunk, + anchor_pred_noise, + ) + anchor_float_losses.append(anchor_loss.item()) + # compute anchor loss gradients + # we will accumulate them later + # this saves a ton of memory doing them separately + anchor_loss.backward() + del anchor_pred_noise + del anchor_target_noise_chunk + del anchor_loss + flush() + + anchor_loss_float = sum(anchor_float_losses) / len(anchor_float_losses) + del anchor_chunks + del anchor_target_noise_chunks + del anchor_target_noise + # move anchor back to cpu + anchor.to("cpu") + + with torch.no_grad(): + if self.slider_config.low_ram: + prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size) + denoised_latent_chunks = denoised_latent_chunks # just to have it in one place + positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0) + neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0) + unconditional_latents_chunks = torch.chunk( + unconditional_latents.detach(), + self.prompt_chunk_size, + dim=0 + ) + mask_multiplier_chunks = torch.chunk(mask_multiplier, self.prompt_chunk_size, dim=0) + if unmasked_target is not None: + unmasked_target_chunks = torch.chunk(unmasked_target, self.prompt_chunk_size, dim=0) + else: + unmasked_target_chunks = [None for _ in range(self.prompt_chunk_size)] + else: + # run through in one instance + prompt_pair_chunks = [prompt_pair.detach()] + denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()] + positive_latents_chunks = [positive_latents.detach()] + neutral_latents_chunks = [neutral_latents.detach()] + unconditional_latents_chunks = [unconditional_latents.detach()] + mask_multiplier_chunks = [mask_multiplier] + unmasked_target_chunks = [unmasked_target] + + # flush() + assert len(prompt_pair_chunks) == len(denoised_latent_chunks) + # 3.28 GB RAM for 512x512 + with self.network: + assert self.network.is_active + loss_list = [] + for prompt_pair_chunk, \ + denoised_latent_chunk, \ + positive_latents_chunk, \ + neutral_latents_chunk, \ + 
unconditional_latents_chunk, \ + mask_multiplier_chunk, \ + unmasked_target_chunk \ + in zip( + prompt_pair_chunks, + denoised_latent_chunks, + positive_latents_chunks, + neutral_latents_chunks, + unconditional_latents_chunks, + mask_multiplier_chunks, + unmasked_target_chunks + ): + self.network.multiplier = prompt_pair_chunk.multiplier_list + prompt_pair_chunk.multiplier_list + target_latents = get_noise_pred( + prompt_pair_chunk.positive_target, + prompt_pair_chunk.target_class, + 1, + current_timestep, + denoised_latent_chunk + ) + + guidance_scale = 1.0 + + offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk) + + # make offset multiplier based on actions + offset_multiplier_list = [] + for action in prompt_pair_chunk.action_list: + if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE: + offset_multiplier_list += [-1.0] + elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE: + offset_multiplier_list += [1.0] + + offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype) + # make offset multiplier match rank of offset + offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1) + offset *= offset_multiplier + + offset_neutral = neutral_latents_chunk + # offsets are already adjusted on a per-batch basis + offset_neutral += offset + offset_neutral = offset_neutral.detach().requires_grad_(False) + + # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing + loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none") + + # do inverted mask to preserve non masked + if has_mask and unmasked_target_chunk is not None: + loss = loss * mask_multiplier_chunk + # match the mask unmasked_target_chunk + mask_target_loss = torch.nn.functional.mse_loss( + target_latents.float(), + unmasked_target_chunk.float(), + reduction="none" + ) + mask_target_loss = mask_target_loss * (1.0 - mask_multiplier_chunk) + loss += mask_target_loss + + loss = loss.mean([1, 2, 3]) + + if self.train_config.learnable_snr_gos: + if from_batch: + # match batch size + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, + self.train_config.min_snr_gamma) + else: + # match batch size + timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] + # add snr_gamma + loss = apply_learnable_snr_gos(loss, timesteps_index_list, self.snr_gos) + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + if from_batch: + # match batch size + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, + self.train_config.min_snr_gamma) + else: + # match batch size + timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, + self.train_config.min_snr_gamma) + + + loss = loss.mean() * prompt_pair_chunk.weight + + loss.backward() + loss_list.append(loss.item()) + del target_latents + del offset_neutral + del loss + # flush() + + optimizer.step() + lr_scheduler.step() + + loss_float = sum(loss_list) / len(loss_list) + if anchor_loss_float is not None: + loss_float += anchor_loss_float + + del ( + positive_latents, + neutral_latents, + unconditional_latents, + # latents + ) + # move back to cpu + prompt_pair.to("cpu") + # flush() + + # reset network + self.network.multiplier = 1.0 + + loss_dict = OrderedDict( + {'loss': loss_float}, + ) + if anchor_loss_float is not None: + loss_dict['sl_l'] = loss_float + 
loss_dict['an_l'] = anchor_loss_float + + return loss_dict + # end hook_train_loop diff --git a/jobs/process/TrainSliderProcessOld.py b/jobs/process/TrainSliderProcessOld.py new file mode 100644 index 0000000000000000000000000000000000000000..8c25393a3503a99532676b153794fbe89c609160 --- /dev/null +++ b/jobs/process/TrainSliderProcessOld.py @@ -0,0 +1,408 @@ +# ref: +# - https://github.com/p1atdev/LECO/blob/main/train_lora.py +import time +from collections import OrderedDict +import os +from typing import Optional + +from toolkit.config_modules import SliderConfig +from toolkit.paths import REPOS_ROOT +import sys + +from toolkit.stable_diffusion_model import PromptEmbeds + +sys.path.append(REPOS_ROOT) +sys.path.append(os.path.join(REPOS_ROOT, 'leco')) +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +import gc +from toolkit import train_tools + +import torch +from leco import train_util, model_util +from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion + + +class ACTION_TYPES_SLIDER: + ERASE_NEGATIVE = 0 + ENHANCE_NEGATIVE = 1 + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class EncodedPromptPair: + def __init__( + self, + target_class, + positive, + negative, + neutral, + width=512, + height=512, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=1.0, + weight=1.0 + ): + self.target_class = target_class + self.positive = positive + self.negative = negative + self.neutral = neutral + self.width = width + self.height = height + self.action: int = action + self.multiplier = multiplier + self.weight = weight + + +class PromptEmbedsCache: # 使いまわしたいので + prompts: dict[str, PromptEmbeds] = {} + + def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class EncodedAnchor: + def __init__( + self, + prompt, + neg_prompt, + multiplier=1.0 + ): + self.prompt = prompt + self.neg_prompt = neg_prompt + self.multiplier = multiplier + + +class TrainSliderProcessOld(BaseSDTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.slider_config = SliderConfig(**self.get_conf('slider', {})) + self.prompt_cache = PromptEmbedsCache() + self.prompt_pairs: list[EncodedPromptPair] = [] + self.anchor_pairs: list[EncodedAnchor] = [] + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + cache = PromptEmbedsCache() + prompt_pairs: list[EncodedPromptPair] = [] + + # get encoded latents for our prompts + with torch.no_grad(): + neutral = "" + for target in self.slider_config.targets: + # build the cache + for prompt in [ + target.target_class, + target.positive, + target.negative, + neutral # empty neutral + ]: + if cache[prompt] is None: + cache[prompt] = self.sd.encode_prompt(prompt) + for resolution in self.slider_config.resolutions: + width, height = resolution + only_erase = len(target.positive.strip()) == 0 + only_enhance = len(target.negative.strip()) == 0 + + both = not only_erase and not only_enhance + + if only_erase and only_enhance: + raise ValueError("target must have at least one of positive or negative or both") + # for slider we need to have an enhancer, an eraser, and then + # an inverse with negative weights to 
balance the network + # if we don't do this, we will get different contrast and focus. + # we only perform actions of enhancing and erasing on the negative + # todo work on way to do all of this in one shot + + if both or only_erase: + prompt_pairs += [ + # erase standard + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.positive], + negative=cache[target.negative], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier, + weight=target.weight + ), + ] + if both or only_enhance: + prompt_pairs += [ + # enhance standard, swap pos neg + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.negative], + negative=cache[target.positive], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier, + weight=target.weight + ), + ] + if both: + prompt_pairs += [ + # erase inverted + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.negative], + negative=cache[target.positive], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + prompt_pairs += [ + # enhance inverted + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.positive], + negative=cache[target.negative], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + + # setup anchors + anchor_pairs = [] + for anchor in self.slider_config.anchors: + # build the cache + for prompt in [ + anchor.prompt, + anchor.neg_prompt # empty neutral + ]: + if cache[prompt] == None: + cache[prompt] = self.sd.encode_prompt(prompt) + + anchor_pairs += [ + EncodedAnchor( + prompt=cache[anchor.prompt], + neg_prompt=cache[anchor.neg_prompt], + multiplier=anchor.multiplier + ) + ] + + # move to cpu to save vram + # We don't need text encoder anymore, but keep it on cpu for sampling + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") + self.prompt_cache = cache + self.prompt_pairs = prompt_pairs + self.anchor_pairs = anchor_pairs + flush() + # end hook_before_train_loop + + def hook_train_loop(self, batch): + dtype = get_torch_dtype(self.train_config.dtype) + + # get a random pair + prompt_pair: EncodedPromptPair = self.prompt_pairs[ + torch.randint(0, len(self.prompt_pairs), (1,)).item() + ] + + height = prompt_pair.height + width = prompt_pair.width + target_class = prompt_pair.target_class + neutral = prompt_pair.neutral + negative = prompt_pair.negative + positive = prompt_pair.positive + weight = prompt_pair.weight + multiplier = prompt_pair.multiplier + + unet = self.sd.unet + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + loss_function = torch.nn.MSELoss() + + def get_noise_pred(p, n, gs, cts, dn): + return self.sd.predict_noise( + latents=dn, + text_embeddings=train_tools.concat_prompt_embeddings( + p, # unconditional + n, # positive + self.train_config.batch_size, + ), + timestep=cts, + guidance_scale=gs, + ) + + # set network multiplier + self.network.multiplier = multiplier + + with torch.no_grad(): + self.sd.noise_scheduler.set_timesteps( + 
self.train_config.max_denoising_steps, device=self.device_torch + ) + + self.optimizer.zero_grad() + + # ger a random number of steps + timesteps_to = torch.randint( + 1, self.train_config.max_denoising_steps, (1,) + ).item() + + # get noise + noise = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + # get latents + latents = noise * self.sd.noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) + + with self.network: + assert self.network.is_active + self.network.multiplier = multiplier + denoised_latents = self.sd.diffuse_some_steps( + latents, # pass simple noise latents + train_tools.concat_prompt_embeddings( + positive, # unconditional + target_class, # target + self.train_config.batch_size, + ), + start_timesteps=0, + total_timesteps=timesteps_to, + guidance_scale=3, + ) + + noise_scheduler.set_timesteps(1000) + + current_timestep = noise_scheduler.timesteps[ + int(timesteps_to * 1000 / self.train_config.max_denoising_steps) + ] + + positive_latents = get_noise_pred( + positive, negative, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + neutral_latents = get_noise_pred( + positive, neutral, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + unconditional_latents = get_noise_pred( + positive, positive, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + anchor_loss = None + if len(self.anchor_pairs) > 0: + # get a random anchor pair + anchor: EncodedAnchor = self.anchor_pairs[ + torch.randint(0, len(self.anchor_pairs), (1,)).item() + ] + with torch.no_grad(): + anchor_target_noise = get_noise_pred( + anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + with self.network: + # anchor whatever weight prompt pair is using + pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0 + self.network.multiplier = anchor.multiplier * pos_nem_mult + + anchor_pred_noise = get_noise_pred( + anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + self.network.multiplier = prompt_pair.multiplier + + with self.network: + self.network.multiplier = prompt_pair.multiplier + target_latents = get_noise_pred( + positive, target_class, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + # if self.logging_config.verbose: + # self.print("target_latents:", target_latents[0, 0, :5, :5]) + + positive_latents.requires_grad = False + neutral_latents.requires_grad = False + unconditional_latents.requires_grad = False + if len(self.anchor_pairs) > 0: + anchor_target_noise.requires_grad = False + anchor_loss = loss_function( + anchor_target_noise, + anchor_pred_noise, + ) + erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE + guidance_scale = 1.0 + + offset = guidance_scale * (positive_latents - unconditional_latents) + + offset_neutral = neutral_latents + if erase: + offset_neutral -= offset + else: + # enhance + offset_neutral += offset + + loss = loss_function( + target_latents, + offset_neutral, + ) * weight + + loss_slide = loss.item() + + if anchor_loss is not None: + loss += anchor_loss + + loss_float = loss.item() + + loss = loss.to(self.device_torch) + + loss.backward() + optimizer.step() + lr_scheduler.step() + + del ( + positive_latents, + neutral_latents, + unconditional_latents, + target_latents, 
+ latents, + ) + flush() + + # reset network + self.network.multiplier = 1.0 + + loss_dict = OrderedDict( + {'loss': loss_float}, + ) + if anchor_loss is not None: + loss_dict['sl_l'] = loss_slide + loss_dict['an_l'] = anchor_loss.item() + + return loss_dict + # end hook_train_loop diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6536cdbc62d23104a465623bb7745bd4df00b5 --- /dev/null +++ b/jobs/process/TrainVAEProcess.py @@ -0,0 +1,612 @@ +import copy +import glob +import os +import shutil +import time +from collections import OrderedDict + +from PIL import Image +from PIL.ImageOps import exif_transpose +from safetensors.torch import save_file, load_file +from torch.utils.data import DataLoader, ConcatDataset +import torch +from torch import nn +from torchvision.transforms import transforms + +from jobs.process import BaseTrainProcess +from toolkit.image_utils import show_tensors +from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm +from toolkit.data_loader import ImageDataset +from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss +from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer +from toolkit.style import get_style_model_and_losses +from toolkit.train_tools import get_torch_dtype +from diffusers import AutoencoderKL +from tqdm import tqdm +import time +import numpy as np +from .models.vgg19_critic import Critic +from torchvision.transforms import Resize +import lpips + +IMAGE_TRANSFORMS = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] +) + + +def unnormalize(tensor): + return (tensor / 2 + 0.5).clamp(0, 1) + + +class TrainVAEProcess(BaseTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.data_loader = None + self.vae = None + self.device = self.get_conf('device', self.job.device) + self.vae_path = self.get_conf('vae_path', required=True) + self.datasets_objects = self.get_conf('datasets', required=True) + self.batch_size = self.get_conf('batch_size', 1, as_type=int) + self.resolution = self.get_conf('resolution', 256, as_type=int) + self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float) + self.sample_every = self.get_conf('sample_every', None) + self.optimizer_type = self.get_conf('optimizer', 'adam') + self.epochs = self.get_conf('epochs', None, as_type=int) + self.max_steps = self.get_conf('max_steps', None, as_type=int) + self.save_every = self.get_conf('save_every', None) + self.dtype = self.get_conf('dtype', 'float32') + self.sample_sources = self.get_conf('sample_sources', None) + self.log_every = self.get_conf('log_every', 100, as_type=int) + self.style_weight = self.get_conf('style_weight', 0, as_type=float) + self.content_weight = self.get_conf('content_weight', 0, as_type=float) + self.kld_weight = self.get_conf('kld_weight', 0, as_type=float) + self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) + self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) + self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float) + self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) + self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float) + self.optimizer_params = self.get_conf('optimizer_params', {}) + + self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) + 
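# 'blocks_to_train' selects which decoder sub-modules run() unfreezes: 'all', + # 'mid_block', 'up_blocks', and/or 'conv_out' +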
self.torch_dtype = get_torch_dtype(self.dtype) + self.vgg_19 = None + self.style_weight_scalers = [] + self.content_weight_scalers = [] + self.lpips_loss:lpips.LPIPS = None + + self.vae_scale_factor = 8 + + self.step_num = 0 + self.epoch_num = 0 + + self.use_critic = self.get_conf('use_critic', False, as_type=bool) + self.critic = None + + if self.use_critic: + self.critic = Critic( + device=self.device, + dtype=self.dtype, + process=self, + **self.get_conf('critic', {}) # pass any other params + ) + + if self.sample_every is not None and self.sample_sources is None: + raise ValueError('sample_every is specified but sample_sources is not') + + if self.epochs is None and self.max_steps is None: + raise ValueError('epochs or max_steps must be specified') + + self.data_loaders = [] + # check datasets + assert isinstance(self.datasets_objects, list) + for dataset in self.datasets_objects: + if 'path' not in dataset: + raise ValueError('dataset must have a path') + # check if is dir + if not os.path.isdir(dataset['path']): + raise ValueError(f"dataset path does is not a directory: {dataset['path']}") + + # make training folder + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + self._pattern_loss = None + + def update_training_metadata(self): + self.add_meta(OrderedDict({"training_info": self.get_training_info()})) + + def get_training_info(self): + info = OrderedDict({ + 'step': self.step_num, + 'epoch': self.epoch_num, + }) + return info + + def load_datasets(self): + if self.data_loader is None: + print(f"Loading datasets") + datasets = [] + for dataset in self.datasets_objects: + print(f" - Dataset: {dataset['path']}") + ds = copy.copy(dataset) + ds['resolution'] = self.resolution + image_dataset = ImageDataset(ds) + datasets.append(image_dataset) + + concatenated_dataset = ConcatDataset(datasets) + self.data_loader = DataLoader( + concatenated_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=6 + ) + + def remove_oldest_checkpoint(self): + max_to_keep = 4 + folders = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers")) + if len(folders) > max_to_keep: + folders.sort(key=os.path.getmtime) + for folder in folders[:-max_to_keep]: + print(f"Removing {folder}") + shutil.rmtree(folder) + + def setup_vgg19(self): + if self.vgg_19 is None: + self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses( + single_target=True, + device=self.device, + output_layer_name='pool_4', + dtype=self.torch_dtype + ) + self.vgg_19.to(self.device, dtype=self.torch_dtype) + self.vgg_19.requires_grad_(False) + + # we run random noise through first to get layer scalers to normalize the loss per layer + # bs of 2 because we run pred and target through stacked + noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype) + self.vgg_19(noise) + for style_loss in self.style_losses: + # get a scaler to normalize to 1 + scaler = 1 / torch.mean(style_loss.loss).item() + self.style_weight_scalers.append(scaler) + for content_loss in self.content_losses: + # get a scaler to normalize to 1 + scaler = 1 / torch.mean(content_loss.loss).item() + self.content_weight_scalers.append(scaler) + + self.print(f"Style weight scalers: {self.style_weight_scalers}") + self.print(f"Content weight scalers: {self.content_weight_scalers}") + + def get_style_loss(self): + if self.style_weight > 0: + # scale all losses with loss scalers + loss = torch.sum( + torch.stack([loss.loss * 
scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)])) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_content_loss(self): + if self.content_weight > 0: + # scale all losses with loss scalers + loss = torch.sum(torch.stack( + [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)])) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_mse_loss(self, pred, target): + if self.mse_weight > 0: + loss_fn = nn.MSELoss() + loss = loss_fn(pred, target) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_kld_loss(self, mu, log_var): + if self.kld_weight > 0: + # Kullback-Leibler divergence + # added here for full training (not implemented). Not needed for only decoder + # as we are not changing the distribution of the latent space + # normally it would help keep a normal distribution for latents + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL divergence + return KLD + else: + return torch.tensor(0.0, device=self.device) + + def get_tv_loss(self, pred, target): + if self.tv_weight > 0: + get_tv_loss = ComparativeTotalVariation() + loss = get_tv_loss(pred, target) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_pattern_loss(self, pred, target): + if self._pattern_loss is None: + self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(self.device, + dtype=self.torch_dtype) + loss = torch.mean(self._pattern_loss(pred, target)) + return loss + + def save(self, step=None): + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + + self.update_training_metadata() + filename = f'{self.job.name}{step_num}_diffusers' + + self.vae = self.vae.to("cpu", dtype=torch.float16) + self.vae.save_pretrained( + save_directory=os.path.join(self.save_root, filename) + ) + self.vae = self.vae.to(self.device, dtype=self.torch_dtype) + + self.print(f"Saved to {os.path.join(self.save_root, filename)}") + + if self.use_critic: + self.critic.save(step) + + self.remove_oldest_checkpoint() + + def sample(self, step=None): + sample_folder = os.path.join(self.save_root, 'samples') + if not os.path.exists(sample_folder): + os.makedirs(sample_folder, exist_ok=True) + + with torch.no_grad(): + for i, img_url in enumerate(self.sample_sources): + img = exif_transpose(Image.open(img_url)) + img = img.convert('RGB') + # crop if not square + if img.width != img.height: + min_dim = min(img.width, img.height) + img = img.crop((0, 0, min_dim, min_dim)) + # resize + img = img.resize((self.resolution, self.resolution)) + + input_img = img + img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype) + img = img + decoded = self.vae(img).sample + decoded = (decoded / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy() + + # convert to pillow image + decoded = Image.fromarray((decoded * 255).astype(np.uint8)) + + # stack input image and decoded image + input_img = input_img.resize((self.resolution, self.resolution)) + decoded = decoded.resize((self.resolution, self.resolution)) + + output_img = Image.new('RGB', (self.resolution * 2, self.resolution)) + output_img.paste(input_img, (0, 0)) + output_img.paste(decoded, 
(self.resolution, 0)) + + scale_up = 2 + if output_img.height <= 300: + scale_up = 4 + + # scale up using nearest neighbor + output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST) + + step_num = '' + if step is not None: + # zero-pad 9 digits + step_num = f"_{str(step).zfill(9)}" + seconds_since_epoch = int(time.time()) + # zero-pad 2 digits + i_str = str(i).zfill(2) + filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" + output_img.save(os.path.join(sample_folder, filename)) + + def load_vae(self): + path_to_load = self.vae_path + # see if we have a checkpoint in out output to resume from + self.print(f"Looking for latest checkpoint in {self.save_root}") + files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers")) + if files and len(files) > 0: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + # todo update step and epoch count + else: + self.print(f" - No checkpoint found, starting from scratch") + # load vae + self.print(f"Loading VAE") + self.print(f" - Loading VAE: {path_to_load}") + if self.vae is None: + self.vae = AutoencoderKL.from_pretrained(path_to_load) + + # set decoder to train + self.vae.to(self.device, dtype=self.torch_dtype) + self.vae.requires_grad_(False) + self.vae.eval() + self.vae.decoder.train() + self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1) + + def run(self): + super().run() + self.load_datasets() + + max_step_epochs = self.max_steps // len(self.data_loader) + num_epochs = self.epochs + if num_epochs is None or num_epochs > max_step_epochs: + num_epochs = max_step_epochs + + max_epoch_steps = len(self.data_loader) * num_epochs + num_steps = self.max_steps + if num_steps is None or num_steps > max_epoch_steps: + num_steps = max_epoch_steps + self.max_steps = num_steps + self.epochs = num_epochs + start_step = self.step_num + self.first_step = start_step + + self.print(f"Training VAE") + self.print(f" - Training folder: {self.training_folder}") + self.print(f" - Batch size: {self.batch_size}") + self.print(f" - Learning rate: {self.learning_rate}") + self.print(f" - Epochs: {num_epochs}") + self.print(f" - Max steps: {self.max_steps}") + + # load vae + self.load_vae() + + params = [] + + # only set last 2 layers to trainable + for param in self.vae.decoder.parameters(): + param.requires_grad = False + + train_all = 'all' in self.blocks_to_train + + if train_all: + params = list(self.vae.decoder.parameters()) + self.vae.decoder.requires_grad_(True) + else: + # mid_block + if train_all or 'mid_block' in self.blocks_to_train: + params += list(self.vae.decoder.mid_block.parameters()) + self.vae.decoder.mid_block.requires_grad_(True) + # up_blocks + if train_all or 'up_blocks' in self.blocks_to_train: + params += list(self.vae.decoder.up_blocks.parameters()) + self.vae.decoder.up_blocks.requires_grad_(True) + # conv_out (single conv layer output) + if train_all or 'conv_out' in self.blocks_to_train: + params += list(self.vae.decoder.conv_out.parameters()) + self.vae.decoder.conv_out.requires_grad_(True) + + if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: + self.setup_vgg19() + self.vgg_19.requires_grad_(False) + self.vgg_19.eval() + if self.use_critic: + self.critic.setup() + + if self.lpips_weight > 0 and self.lpips_loss is None: + # self.lpips_loss = lpips.LPIPS(net='vgg') + self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype) + + 
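+        # Loss overview (descriptive comment; summarizes the terms combined in the training loop below):
+        #   loss = style + content + kld + mse + tv + critic_gen + pattern + lpips
+        # Every term is gated by its weight, and the corresponding helper returns a zero tensor when
+        # that weight is not set, so disabled losses simply drop out of the sum. Style and content
+        # losses come from the frozen VGG19 feature extractor and are normalized by the per-layer
+        # scalers computed in setup_vgg19(). When the critic is enabled alongside LPIPS, its generator
+        # term (a WGAN-GP critic run on the vgg19 pool_4 features) is clamped to at most 10% of the
+        # LPIPS magnitude so it cannot dominate the objective early in training.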
optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, + optimizer_params=self.optimizer_params) + + # setup scheduler + # todo allow other schedulers + scheduler = torch.optim.lr_scheduler.ConstantLR( + optimizer, + total_iters=num_steps, + factor=1, + verbose=False + ) + + # setup tqdm progress bar + self.progress_bar = tqdm( + total=num_steps, + desc='Training VAE', + leave=True + ) + + # sample first + self.sample() + blank_losses = OrderedDict({ + "total": [], + "lpips": [], + "style": [], + "content": [], + "mse": [], + "kl": [], + "tv": [], + "ptn": [], + "crD": [], + "crG": [], + }) + epoch_losses = copy.deepcopy(blank_losses) + log_losses = copy.deepcopy(blank_losses) + # range start at self.epoch_num go to self.epochs + for epoch in range(self.epoch_num, self.epochs, 1): + if self.step_num >= self.max_steps: + break + for batch in self.data_loader: + if self.step_num >= self.max_steps: + break + with torch.no_grad(): + + batch = batch.to(self.device, dtype=self.torch_dtype) + + # resize so it matches size of vae evenly + if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0: + batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor, + batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch) + + # forward pass + dgd = self.vae.encode(batch).latent_dist + mu, logvar = dgd.mean, dgd.logvar + latents = dgd.sample() + latents.detach().requires_grad_(True) + + pred = self.vae.decode(latents).sample + + with torch.no_grad(): + show_tensors( + pred.clamp(-1, 1).clone(), + "combined tensor" + ) + + # Run through VGG19 + if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: + stacked = torch.cat([pred, batch], dim=0) + stacked = (stacked / 2 + 0.5).clamp(0, 1) + self.vgg_19(stacked) + + if self.use_critic: + critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach()) + else: + critic_d_loss = 0.0 + + style_loss = self.get_style_loss() * self.style_weight + content_loss = self.get_content_loss() * self.content_weight + kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight + mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight + if self.lpips_weight > 0: + lpips_loss = self.lpips_loss( + pred.clamp(-1, 1), + batch.clamp(-1, 1) + ).mean() * self.lpips_weight + else: + lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight + pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight + if self.use_critic: + critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight + + # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it + if self.lpips_weight > 0: + max_target = lpips_loss.abs() * 0.1 + with torch.no_grad(): + crit_g_scaler = 1.0 + if critic_gen_loss.abs() > max_target: + crit_g_scaler = max_target / critic_gen_loss.abs() + + critic_gen_loss *= crit_g_scaler + else: + critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + + # Backward pass and optimization + optimizer.zero_grad() + loss.backward() + optimizer.step() + scheduler.step() + + # update progress bar + loss_value = loss.item() + # get exponent like 3.54e-4 + loss_string = f"loss: {loss_value:.2e}" + if self.lpips_weight > 0: + loss_string += f" lpips: {lpips_loss.item():.2e}" + if self.content_weight > 0: + 
loss_string += f" cnt: {content_loss.item():.2e}" + if self.style_weight > 0: + loss_string += f" sty: {style_loss.item():.2e}" + if self.kld_weight > 0: + loss_string += f" kld: {kld_loss.item():.2e}" + if self.mse_weight > 0: + loss_string += f" mse: {mse_loss.item():.2e}" + if self.tv_weight > 0: + loss_string += f" tv: {tv_loss.item():.2e}" + if self.pattern_weight > 0: + loss_string += f" ptn: {pattern_loss.item():.2e}" + if self.use_critic and self.critic_weight > 0: + loss_string += f" crG: {critic_gen_loss.item():.2e}" + if self.use_critic: + loss_string += f" crD: {critic_d_loss:.2e}" + + if self.optimizer_type.startswith('dadaptation') or \ + self.optimizer_type.lower().startswith('prodigy'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + + lr_critic_string = '' + if self.use_critic: + lr_critic = self.critic.get_lr() + lr_critic_string = f" lrC: {lr_critic:.1e}" + + self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}") + self.progress_bar.set_description(f"E: {epoch}") + self.progress_bar.update(1) + + epoch_losses["total"].append(loss_value) + epoch_losses["lpips"].append(lpips_loss.item()) + epoch_losses["style"].append(style_loss.item()) + epoch_losses["content"].append(content_loss.item()) + epoch_losses["mse"].append(mse_loss.item()) + epoch_losses["kl"].append(kld_loss.item()) + epoch_losses["tv"].append(tv_loss.item()) + epoch_losses["ptn"].append(pattern_loss.item()) + epoch_losses["crG"].append(critic_gen_loss.item()) + epoch_losses["crD"].append(critic_d_loss) + + log_losses["total"].append(loss_value) + log_losses["lpips"].append(lpips_loss.item()) + log_losses["style"].append(style_loss.item()) + log_losses["content"].append(content_loss.item()) + log_losses["mse"].append(mse_loss.item()) + log_losses["kl"].append(kld_loss.item()) + log_losses["tv"].append(tv_loss.item()) + log_losses["ptn"].append(pattern_loss.item()) + log_losses["crG"].append(critic_gen_loss.item()) + log_losses["crD"].append(critic_d_loss) + + # don't do on first step + if self.step_num != start_step: + if self.sample_every and self.step_num % self.sample_every == 0: + # print above the progress bar + self.print(f"Sampling at step {self.step_num}") + self.sample(self.step_num) + + if self.save_every and self.step_num % self.save_every == 0: + # print above the progress bar + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) + + if self.log_every and self.step_num % self.log_every == 0: + # log to tensorboard + if self.writer is not None: + # get avg loss + for key in log_losses: + log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6) + # if log_losses[key] > 0: + self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num) + # reset log losses + log_losses = copy.deepcopy(blank_losses) + + self.step_num += 1 + # end epoch + if self.writer is not None: + eps = 1e-6 + # get avg loss + for key in epoch_losses: + epoch_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + eps) + if epoch_losses[key] > 0: + self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch) + # reset epoch losses + epoch_losses = copy.deepcopy(blank_losses) + + self.save() diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..387be08853c3bcb5cbc551d0b1c1c99dad124df6 --- /dev/null +++ b/jobs/process/__init__.py @@ -0,0 +1,15 @@ +from 
.BaseExtractProcess import BaseExtractProcess +from .ExtractLoconProcess import ExtractLoconProcess +from .ExtractLoraProcess import ExtractLoraProcess +from .BaseProcess import BaseProcess +from .BaseTrainProcess import BaseTrainProcess +from .TrainVAEProcess import TrainVAEProcess +from .BaseMergeProcess import BaseMergeProcess +from .TrainSliderProcess import TrainSliderProcess +from .TrainSliderProcessOld import TrainSliderProcessOld +from .TrainSDRescaleProcess import TrainSDRescaleProcess +from .ModRescaleLoraProcess import ModRescaleLoraProcess +from .GenerateProcess import GenerateProcess +from .BaseExtensionProcess import BaseExtensionProcess +from .TrainESRGANProcess import TrainESRGANProcess +from .BaseSDTrainProcess import BaseSDTrainProcess diff --git a/jobs/process/models/vgg19_critic.py b/jobs/process/models/vgg19_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf438bf11487d3daff86266f515f477dcaf88cd --- /dev/null +++ b/jobs/process/models/vgg19_critic.py @@ -0,0 +1,194 @@ +import glob +import os + +import numpy as np +import torch +import torch.nn as nn +from safetensors.torch import load_file, save_file + +from toolkit.losses import get_gradient_penalty +from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer +from toolkit.train_tools import get_torch_dtype + +from typing import TYPE_CHECKING, Union + + +class MeanReduce(nn.Module): + def __init__(self): + super(MeanReduce, self).__init__() + + def forward(self, inputs): + return torch.mean(inputs, dim=(1, 2, 3), keepdim=True) + + +class Vgg19Critic(nn.Module): + def __init__(self): + # vgg19 input (bs, 3, 512, 512) + # pool1 (bs, 64, 256, 256) + # pool2 (bs, 128, 128, 128) + # pool3 (bs, 256, 64, 64) + # pool4 (bs, 512, 32, 32) <- take this input + + super(Vgg19Critic, self).__init__() + self.main = nn.Sequential( + # input (bs, 512, 32, 32) + nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2), # (bs, 512, 16, 16) + nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2), # (bs, 512, 8, 8) + nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1), + # (bs, 1, 4, 4) + MeanReduce(), # (bs, 1, 1, 1) + nn.Flatten(), # (bs, 1) + + # nn.Flatten(), # (128*8*8) = 8192 + # nn.Linear(128 * 8 * 8, 1) + ) + + def forward(self, inputs): + return self.main(inputs) + + +if TYPE_CHECKING: + from jobs.process.TrainVAEProcess import TrainVAEProcess + from jobs.process.TrainESRGANProcess import TrainESRGANProcess + + +class Critic: + process: Union['TrainVAEProcess', 'TrainESRGANProcess'] + + def __init__( + self, + learning_rate=1e-5, + device='cpu', + optimizer='adam', + num_critic_per_gen=1, + dtype='float32', + lambda_gp=10, + start_step=0, + warmup_steps=1000, + process=None, + optimizer_params=None, + ): + self.learning_rate = learning_rate + self.device = device + self.optimizer_type = optimizer + self.num_critic_per_gen = num_critic_per_gen + self.dtype = dtype + self.torch_dtype = get_torch_dtype(self.dtype) + self.process = process + self.model = None + self.optimizer = None + self.scheduler = None + self.warmup_steps = warmup_steps + self.start_step = start_step + self.lambda_gp = lambda_gp + + if optimizer_params is None: + optimizer_params = {} + self.optimizer_params = optimizer_params + self.print = self.process.print + print(f" Critic config: {self.__dict__}") + + def setup(self): + self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype) + self.load_weights() + self.model.train() + 
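+        # Descriptive note: the critic keeps its own optimizer and constant-LR scheduler, independent
+        # of the VAE's. total_iters is scaled by num_critic_per_gen since that setting is meant to
+        # allow more than one critic update per generator step. load_weights() above resumes from the
+        # newest CRITIC_{job name}*.safetensors checkpoint found in save_root, if any.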
self.model.requires_grad_(True) + params = self.model.parameters() + self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, + optimizer_params=self.optimizer_params) + self.scheduler = torch.optim.lr_scheduler.ConstantLR( + self.optimizer, + total_iters=self.process.max_steps * self.num_critic_per_gen, + factor=1, + verbose=False + ) + + def load_weights(self): + path_to_load = None + self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}") + files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors")) + if files and len(files) > 0: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + else: + self.print(f" - No checkpoint found, starting from scratch") + if path_to_load: + self.model.load_state_dict(load_file(path_to_load)) + + def save(self, step=None): + self.process.update_training_metadata() + save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name) + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors") + save_file(self.model.state_dict(), save_path, save_meta) + self.print(f"Saved critic to {save_path}") + + def get_critic_loss(self, vgg_output): + if self.start_step > self.process.step_num: + return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device) + + warmup_scaler = 1.0 + # we need a warmup when we come on of 1000 steps + # we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps + if self.process.step_num < self.start_step + self.warmup_steps: + warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps + # set model to not train for generator loss + self.model.eval() + self.model.requires_grad_(False) + vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0) + + # run model + stacked_output = self.model(vgg_pred) + + return (-torch.mean(stacked_output)) * warmup_scaler + + def step(self, vgg_output): + + # train critic here + self.model.train() + self.model.requires_grad_(True) + self.optimizer.zero_grad() + + critic_losses = [] + inputs = vgg_output.detach() + inputs = inputs.to(self.device, dtype=self.torch_dtype) + self.optimizer.zero_grad() + + vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0) + + stacked_output = self.model(inputs).float() + out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) + + # Compute gradient penalty + gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) + + # Compute WGAN-GP critic loss + critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty + critic_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + self.scheduler.step() + critic_losses.append(critic_loss.item()) + + # avg loss + loss = np.mean(critic_losses) + return loss + + def get_lr(self): + if self.optimizer_type.startswith('dadaptation'): + learning_rate = ( + self.optimizer.param_groups[0]["d"] * + self.optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = self.optimizer.param_groups[0]['lr'] + + return learning_rate + diff --git a/notebooks/FLUX_1_dev_LoRA_Training.ipynb b/notebooks/FLUX_1_dev_LoRA_Training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..8cfcd1fedfc941ac1a050f39499f77d303e23783 --- 
/dev/null +++ b/notebooks/FLUX_1_dev_LoRA_Training.ipynb @@ -0,0 +1,291 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "id": "zl-S0m3pkQC5" + }, + "source": [ + "# AI Toolkit by Ostris\n", + "## FLUX.1-dev Training\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BvAG0GKAh59G" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/ostris/ai-toolkit\n", + "!mkdir -p /content/dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UFUW4ZMmnp1V" + }, + "source": [ + "Put your image dataset in the `/content/dataset` folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XGZqVER_aQJW" + }, + "outputs": [], + "source": [ + "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OV0HnOI6o8V6" + }, + "source": [ + "## Model License\n", + "Training currently only works with FLUX.1-dev. Which means anything you train will inherit the non-commercial license. It is also a gated model, so you need to accept the license on HF before using it. Otherwise, this will fail. Here are the required steps to setup a license.\n", + "\n", + "Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)\n", + "\n", + "[Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and place it in the next cell after running it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3yZZdhFRoj2m" + }, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "# Prompt for the token\n", + "hf_token = getpass.getpass('Enter your HF access token and press enter: ')\n", + "\n", + "# Set the environment variable\n", + "os.environ['HF_TOKEN'] = hf_token\n", + "\n", + "print(\"HF_TOKEN environment variable has been set.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9gO2EzQ1kQC8" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "sys.path.append('/content/ai-toolkit')\n", + "from toolkit.job import run_job\n", + "from collections import OrderedDict\n", + "from PIL import Image\n", + "import os\n", + "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N8UUFzVRigbC" + }, + "source": [ + "## Setup\n", + "\n", + "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_t28QURYjRQO" + }, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "job_to_run = OrderedDict([\n", + " ('job', 'extension'),\n", + " ('config', OrderedDict([\n", + " # this name will be the folder and filename name\n", + " ('name', 'my_first_flux_lora_v1'),\n", + " ('process', [\n", + " OrderedDict([\n", + " ('type', 'sd_trainer'),\n", + " # root folder to save training sessions/samples/weights\n", + " ('training_folder', '/content/output'),\n", + " # uncomment to see performance stats in the terminal every N steps\n", + " #('performance_log_every', 1000),\n", + " ('device', 'cuda:0'),\n", + " # if a trigger word is specified, it will be added to captions of training data if it does not already exist\n", + " # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word\n", + " # ('trigger_word', 'image'),\n", + " ('network', OrderedDict([\n", + " ('type', 'lora'),\n", + " ('linear', 16),\n", + " ('linear_alpha', 16)\n", + " ])),\n", + " ('save', OrderedDict([\n", + " ('dtype', 'float16'), # precision to save\n", + " ('save_every', 250), # save every this many steps\n", + " ('max_step_saves_to_keep', 4) # how many intermittent saves to keep\n", + " ])),\n", + " ('datasets', [\n", + " # datasets are a folder of images. captions need to be txt files with the same name as the image\n", + " # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently\n", + " # images will automatically be resized and bucketed into the resolution specified\n", + " OrderedDict([\n", + " ('folder_path', '/content/dataset'),\n", + " ('caption_ext', 'txt'),\n", + " ('caption_dropout_rate', 0.05), # will drop out the caption 5% of time\n", + " ('shuffle_tokens', False), # shuffle caption order, split by commas\n", + " ('cache_latents_to_disk', True), # leave this true unless you know what you're doing\n", + " ('resolution', [512, 768, 1024]) # flux enjoys multiple resolutions\n", + " ])\n", + " ]),\n", + " ('train', OrderedDict([\n", + " ('batch_size', 1),\n", + " ('steps', 2000), # total number of steps to train 500 - 4000 is a good range\n", + " ('gradient_accumulation_steps', 1),\n", + " ('train_unet', True),\n", + " ('train_text_encoder', False), # probably won't work with flux\n", + " ('content_or_style', 'balanced'), # content, style, balanced\n", + " ('gradient_checkpointing', True), # need the on unless you have a ton of vram\n", + " ('noise_scheduler', 'flowmatch'), # for training only\n", + " ('optimizer', 'adamw8bit'),\n", + " ('lr', 1e-4),\n", + "\n", + " # uncomment this to skip the pre training sample\n", + " # ('skip_first_sample', True),\n", + "\n", + " # uncomment to completely disable sampling\n", + " # ('disable_sampling', True),\n", + "\n", + " # uncomment to use new vell curved weighting. Experimental but may produce better results\n", + " # ('linear_timesteps', True),\n", + "\n", + " # ema will smooth out learning, but could slow it down. 
Recommended to leave on.\n", + " ('ema_config', OrderedDict([\n", + " ('use_ema', True),\n", + " ('ema_decay', 0.99)\n", + " ])),\n", + "\n", + " # will probably need this if gpu supports it for flux, other dtypes may not work correctly\n", + " ('dtype', 'bf16')\n", + " ])),\n", + " ('model', OrderedDict([\n", + " # huggingface model name or path\n", + " ('name_or_path', 'black-forest-labs/FLUX.1-dev'),\n", + " ('is_flux', True),\n", + " ('quantize', True), # run 8bit mixed precision\n", + " #('low_vram', True), # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.\n", + " ])),\n", + " ('sample', OrderedDict([\n", + " ('sampler', 'flowmatch'), # must match train.noise_scheduler\n", + " ('sample_every', 250), # sample every this many steps\n", + " ('width', 1024),\n", + " ('height', 1024),\n", + " ('prompts', [\n", + " # you can add [trigger] to the prompts here and it will be replaced with the trigger word\n", + " #'[trigger] holding a sign that says \\'I LOVE PROMPTS!\\'',\n", + " 'woman with red hair, playing chess at the park, bomb going off in the background',\n", + " 'a woman holding a coffee cup, in a beanie, sitting at a cafe',\n", + " 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',\n", + " 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',\n", + " 'a bear building a log cabin in the snow covered mountains',\n", + " 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',\n", + " 'hipster man with a beard, building a chair, in a wood shop',\n", + " 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',\n", + " 'a man holding a sign that says, \\'this is a sign\\'',\n", + " 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle'\n", + " ]),\n", + " ('neg', ''), # not used on flux\n", + " ('seed', 42),\n", + " ('walk_seed', True),\n", + " ('guidance_scale', 4),\n", + " ('sample_steps', 20)\n", + " ]))\n", + " ])\n", + " ])\n", + " ])),\n", + " # you can add any additional meta info here. [name] is replaced with config name at top\n", + " ('meta', OrderedDict([\n", + " ('name', '[name]'),\n", + " ('version', '1.0')\n", + " ]))\n", + "])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h6F1FlM2Wb3l" + }, + "source": [ + "## Run it\n", + "\n", + "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1 In the samples folder, there are preiodic sampled. This doesnt work great with colab. 
They will be in /content/output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HkajwI8gteOh" + }, + "outputs": [], + "source": [ + "run_job(job_to_run)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hblgb5uwW5SD" + }, + "source": [ + "## Done\n", + "\n", + "Check your ourput dir and get your slider\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/FLUX_1_schnell_LoRA_Training.ipynb b/notebooks/FLUX_1_schnell_LoRA_Training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..652d8ccc19f8996734785182ce8de46a5c7408fb --- /dev/null +++ b/notebooks/FLUX_1_schnell_LoRA_Training.ipynb @@ -0,0 +1,296 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "id": "zl-S0m3pkQC5" + }, + "source": [ + "# AI Toolkit by Ostris\n", + "## FLUX.1-schnell Training\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3cokMT-WC6rG" + }, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "BvAG0GKAh59G" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/ostris/ai-toolkit\n", + "!mkdir -p /content/dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UFUW4ZMmnp1V" + }, + "source": [ + "Put your image dataset in the `/content/dataset` folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "XGZqVER_aQJW" + }, + "outputs": [], + "source": [ + "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OV0HnOI6o8V6" + }, + "source": [ + "## Model License\n", + "Training currently only works with FLUX.1-dev. Which means anything you train will inherit the non-commercial license. It is also a gated model, so you need to accept the license on HF before using it. Otherwise, this will fail. Here are the required steps to setup a license.\n", + "\n", + "Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)\n", + "\n", + "[Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and place it in the next cell after running it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3yZZdhFRoj2m" + }, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "# Prompt for the token\n", + "hf_token = getpass.getpass('Enter your HF access token and press enter: ')\n", + "\n", + "# Set the environment variable\n", + "os.environ['HF_TOKEN'] = hf_token\n", + "\n", + "print(\"HF_TOKEN environment variable has been set.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "9gO2EzQ1kQC8" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "sys.path.append('/content/ai-toolkit')\n", + "from toolkit.job import run_job\n", + "from collections import OrderedDict\n", + "from PIL import Image\n", + "import os\n", + "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N8UUFzVRigbC" + }, + "source": [ + "## Setup\n", + "\n", + "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "_t28QURYjRQO" + }, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "job_to_run = OrderedDict([\n", + " ('job', 'extension'),\n", + " ('config', OrderedDict([\n", + " # this name will be the folder and filename name\n", + " ('name', 'my_first_flux_lora_v1'),\n", + " ('process', [\n", + " OrderedDict([\n", + " ('type', 'sd_trainer'),\n", + " # root folder to save training sessions/samples/weights\n", + " ('training_folder', '/content/output'),\n", + " # uncomment to see performance stats in the terminal every N steps\n", + " #('performance_log_every', 1000),\n", + " ('device', 'cuda:0'),\n", + " # if a trigger word is specified, it will be added to captions of training data if it does not already exist\n", + " # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word\n", + " # ('trigger_word', 'image'),\n", + " ('network', OrderedDict([\n", + " ('type', 'lora'),\n", + " ('linear', 16),\n", + " ('linear_alpha', 16)\n", + " ])),\n", + " ('save', OrderedDict([\n", + " ('dtype', 'float16'), # precision to save\n", + " ('save_every', 250), # save every this many steps\n", + " ('max_step_saves_to_keep', 4) # how many intermittent saves to keep\n", + " ])),\n", + " ('datasets', [\n", + " # datasets are a folder of images. captions need to be txt files with the same name as the image\n", + " # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently\n", + " # images will automatically be resized and bucketed into the resolution specified\n", + " OrderedDict([\n", + " ('folder_path', '/content/dataset'),\n", + " ('caption_ext', 'txt'),\n", + " ('caption_dropout_rate', 0.05), # will drop out the caption 5% of time\n", + " ('shuffle_tokens', False), # shuffle caption order, split by commas\n", + " ('cache_latents_to_disk', True), # leave this true unless you know what you're doing\n", + " ('resolution', [512, 768, 1024]) # flux enjoys multiple resolutions\n", + " ])\n", + " ]),\n", + " ('train', OrderedDict([\n", + " ('batch_size', 1),\n", + " ('steps', 2000), # total number of steps to train 500 - 4000 is a good range\n", + " ('gradient_accumulation_steps', 1),\n", + " ('train_unet', True),\n", + " ('train_text_encoder', False), # probably won't work with flux\n", + " ('gradient_checkpointing', True), # need the on unless you have a ton of vram\n", + " ('noise_scheduler', 'flowmatch'), # for training only\n", + " ('optimizer', 'adamw8bit'),\n", + " ('lr', 1e-4),\n", + "\n", + " # uncomment this to skip the pre training sample\n", + " # ('skip_first_sample', True),\n", + "\n", + " # uncomment to completely disable sampling\n", + " # ('disable_sampling', True),\n", + "\n", + " # uncomment to use new vell curved weighting. Experimental but may produce better results\n", + " # ('linear_timesteps', True),\n", + "\n", + " # ema will smooth out learning, but could slow it down. Recommended to leave on.\n", + " ('ema_config', OrderedDict([\n", + " ('use_ema', True),\n", + " ('ema_decay', 0.99)\n", + " ])),\n", + "\n", + " # will probably need this if gpu supports it for flux, other dtypes may not work correctly\n", + " ('dtype', 'bf16')\n", + " ])),\n", + " ('model', OrderedDict([\n", + " # huggingface model name or path\n", + " ('name_or_path', 'black-forest-labs/FLUX.1-schnell'),\n", + " ('assistant_lora_path', 'ostris/FLUX.1-schnell-training-adapter'), # Required for flux schnell training\n", + " ('is_flux', True),\n", + " ('quantize', True), # run 8bit mixed precision\n", + " # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary\n", + " #('low_vram', True), # uncomment this if the GPU is connected to your monitors. 
It will use less vram to quantize, but is slower.\n", + " ])),\n", + " ('sample', OrderedDict([\n", + " ('sampler', 'flowmatch'), # must match train.noise_scheduler\n", + " ('sample_every', 250), # sample every this many steps\n", + " ('width', 1024),\n", + " ('height', 1024),\n", + " ('prompts', [\n", + " # you can add [trigger] to the prompts here and it will be replaced with the trigger word\n", + " #'[trigger] holding a sign that says \\'I LOVE PROMPTS!\\'',\n", + " 'woman with red hair, playing chess at the park, bomb going off in the background',\n", + " 'a woman holding a coffee cup, in a beanie, sitting at a cafe',\n", + " 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',\n", + " 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',\n", + " 'a bear building a log cabin in the snow covered mountains',\n", + " 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',\n", + " 'hipster man with a beard, building a chair, in a wood shop',\n", + " 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',\n", + " 'a man holding a sign that says, \\'this is a sign\\'',\n", + " 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle'\n", + " ]),\n", + " ('neg', ''), # not used on flux\n", + " ('seed', 42),\n", + " ('walk_seed', True),\n", + " ('guidance_scale', 1), # schnell does not do guidance\n", + " ('sample_steps', 4) # 1 - 4 works well\n", + " ]))\n", + " ])\n", + " ])\n", + " ])),\n", + " # you can add any additional meta info here. [name] is replaced with config name at top\n", + " ('meta', OrderedDict([\n", + " ('name', '[name]'),\n", + " ('version', '1.0')\n", + " ]))\n", + "])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h6F1FlM2Wb3l" + }, + "source": [ + "## Run it\n", + "\n", + "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1 In the samples folder, there are preiodic sampled. This doesnt work great with colab. 
They will be in /content/output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HkajwI8gteOh" + }, + "outputs": [], + "source": [ + "run_job(job_to_run)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hblgb5uwW5SD" + }, + "source": [ + "## Done\n", + "\n", + "Check your ourput dir and get your slider\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/SliderTraining.ipynb b/notebooks/SliderTraining.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..8465ec87dc2d2dce8f11e122c28c80297e3ea2b9 --- /dev/null +++ b/notebooks/SliderTraining.ipynb @@ -0,0 +1,339 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "machine_shape": "hm", + "gpuType": "V100" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# AI Toolkit by Ostris\n", + "## Slider Training\n", + "\n", + "This is a quick colab demo for training sliders like can be found in my CivitAI profile https://civitai.com/user/Ostris/models . I will work on making it more user friendly, but for now, it will get you started." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/ostris/ai-toolkit" + ], + "metadata": { + "id": "BvAG0GKAh59G" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XGZqVER_aQJW" + }, + "outputs": [], + "source": [ + "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n" + ] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import sys\n", + "sys.path.append('/content/ai-toolkit')\n", + "from toolkit.job import run_job\n", + "from collections import OrderedDict\n", + "from PIL import Image" + ], + "metadata": { + "collapsed": false + }, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Setup\n", + "\n", + "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want." 
+ ], + "metadata": { + "id": "N8UUFzVRigbC" + } + }, + { + "cell_type": "code", + "source": [ + "from collections import OrderedDict\n", + "\n", + "job_to_run = OrderedDict({\n", + " # This is the config I use on my sliders, It is solid and tested\n", + " 'job': 'train',\n", + " 'config': {\n", + " # the name will be used to create a folder in the output folder\n", + " # it will also replace any [name] token in the rest of this config\n", + " 'name': 'detail_slider_v1',\n", + " # folder will be created with name above in folder below\n", + " # it can be relative to the project root or absolute\n", + " 'training_folder': \"output/LoRA\",\n", + " 'device': 'cuda', # cpu, cuda:0, etc\n", + " # for tensorboard logging, we will make a subfolder for this job\n", + " 'log_dir': \"output/.tensorboard\",\n", + " # you can stack processes for other jobs, It is not tested with sliders though\n", + " # just use one for now\n", + " 'process': [\n", + " {\n", + " 'type': 'slider', # tells runner to run the slider process\n", + " # network is the LoRA network for a slider, I recommend to leave this be\n", + " 'network': {\n", + " 'type': \"lora\",\n", + " # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good\n", + " 'linear': 8, # \"rank\" or \"dim\"\n", + " 'linear_alpha': 4, # Do about half of rank \"alpha\"\n", + " # 'conv': 4, # for convolutional layers \"locon\"\n", + " # 'conv_alpha': 4, # Do about half of conv \"alpha\"\n", + " },\n", + " # training config\n", + " 'train': {\n", + " # this is also used in sampling. Stick with ddpm unless you know what you are doing\n", + " 'noise_scheduler': \"ddpm\", # or \"ddpm\", \"lms\", \"euler_a\"\n", + " # how many steps to train. More is not always better. I rarely go over 1000\n", + " 'steps': 100,\n", + " # I have had good results with 4e-4 to 1e-4 at 500 steps\n", + " 'lr': 2e-4,\n", + " # enables gradient checkpoint, saves vram, leave it on\n", + " 'gradient_checkpointing': True,\n", + " # train the unet. I recommend leaving this true\n", + " 'train_unet': True,\n", + " # train the text encoder. I don't recommend this unless you have a special use case\n", + " # for sliders we are adjusting representation of the concept (unet),\n", + " # not the description of it (text encoder)\n", + " 'train_text_encoder': False,\n", + "\n", + " # just leave unless you know what you are doing\n", + " # also supports \"dadaptation\" but set lr to 1 if you use that,\n", + " # but it learns too fast and I don't recommend it\n", + " 'optimizer': \"adamw\",\n", + " # only constant for now\n", + " 'lr_scheduler': \"constant\",\n", + " # we randomly denoise random num of steps form 1 to this number\n", + " # while training. Just leave it\n", + " 'max_denoising_steps': 40,\n", + " # works great at 1. I do 1 even with my 4090.\n", + " # higher may not work right with newer single batch stacking code anyway\n", + " 'batch_size': 1,\n", + " # bf16 works best if your GPU supports it (modern)\n", + " 'dtype': 'bf16', # fp32, bf16, fp16\n", + " # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX\n", + " # although, the way we train sliders is comparative, so it probably won't work anyway\n", + " 'noise_offset': 0.0,\n", + " },\n", + "\n", + " # the model to train the LoRA network on\n", + " 'model': {\n", + " # name_or_path can be a hugging face name, local path or url to model\n", + " # on civit ai with or without modelVersionId. 
They will be cached in /model folder\n", + " # epicRealisim v5\n", + " 'name_or_path': \"https://civitai.com/models/25694?modelVersionId=134065\",\n", + " 'is_v2': False, # for v2 models\n", + " 'is_v_pred': False, # for v-prediction models (most v2 models)\n", + " # has some issues with the dual text encoder and the way we train sliders\n", + " # it works bit weights need to probably be higher to see it.\n", + " 'is_xl': False, # for SDXL models\n", + " },\n", + "\n", + " # saving config\n", + " 'save': {\n", + " 'dtype': 'float16', # precision to save. I recommend float16\n", + " 'save_every': 50, # save every this many steps\n", + " # this will remove step counts more than this number\n", + " # allows you to save more often in case of a crash without filling up your drive\n", + " 'max_step_saves_to_keep': 2,\n", + " },\n", + "\n", + " # sampling config\n", + " 'sample': {\n", + " # must match train.noise_scheduler, this is not used here\n", + " # but may be in future and in other processes\n", + " 'sampler': \"ddpm\",\n", + " # sample every this many steps\n", + " 'sample_every': 20,\n", + " # image size\n", + " 'width': 512,\n", + " 'height': 512,\n", + " # prompts to use for sampling. Do as many as you want, but it slows down training\n", + " # pick ones that will best represent the concept you are trying to adjust\n", + " # allows some flags after the prompt\n", + " # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive\n", + " # slide are good tests. will inherit sample.network_multiplier if not set\n", + " # --n [string] # negative prompt, will inherit sample.neg if not set\n", + " # Only 75 tokens allowed currently\n", + " # I like to do a wide positive and negative spread so I can see a good range and stop\n", + " # early if the network is braking down\n", + " 'prompts': [\n", + " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5\",\n", + " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3\",\n", + " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3\",\n", + " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5\",\n", + " \"a golden retriever sitting on a leather couch, --m -5\",\n", + " \"a golden retriever sitting on a leather couch --m -3\",\n", + " \"a golden retriever sitting on a leather couch --m 3\",\n", + " \"a golden retriever sitting on a leather couch --m 5\",\n", + " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5\",\n", + " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3\",\n", + " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3\",\n", + " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5\",\n", + " ],\n", + " # negative prompt used on all prompts above as default if they don't have one\n", + " 'neg': \"cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome\",\n", + " # seed for sampling. 
42 is the answer for everything\n", + " 'seed': 42,\n", + " # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc\n", + " # will start over on next sample_every so s1 is always seed\n", + " # works well if you use same prompt but want different results\n", + " 'walk_seed': False,\n", + " # cfg scale (4 to 10 is good)\n", + " 'guidance_scale': 7,\n", + " # sampler steps (20 to 30 is good)\n", + " 'sample_steps': 20,\n", + " # default network multiplier for all prompts\n", + " # since we are training a slider, I recommend overriding this with --m [number]\n", + " # in the prompts above to get both sides of the slider\n", + " 'network_multiplier': 1.0,\n", + " },\n", + "\n", + " # logging information\n", + " 'logging': {\n", + " 'log_every': 10, # log every this many steps\n", + " 'use_wandb': False, # not supported yet\n", + " 'verbose': False, # probably done need unless you are debugging\n", + " },\n", + "\n", + " # slider training config, best for last\n", + " 'slider': {\n", + " # resolutions to train on. [ width, height ]. This is less important for sliders\n", + " # as we are not teaching the model anything it doesn't already know\n", + " # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1\n", + " # and [ 1024, 1024 ] for sd_xl\n", + " # you can do as many as you want here\n", + " 'resolutions': [\n", + " [512, 512],\n", + " # [ 512, 768 ]\n", + " # [ 768, 768 ]\n", + " ],\n", + " # slider training uses 4 combined steps for a single round. This will do it in one gradient\n", + " # step. It is highly optimized and shouldn't take anymore vram than doing without it,\n", + " # since we break down batches for gradient accumulation now. so just leave it on.\n", + " 'batch_full_slide': True,\n", + " # These are the concepts to train on. You can do as many as you want here,\n", + " # but they can conflict outweigh each other. Other than experimenting, I recommend\n", + " # just doing one for good results\n", + " 'targets': [\n", + " # target_class is the base concept we are adjusting the representation of\n", + " # for example, if we are adjusting the representation of a person, we would use \"person\"\n", + " # if we are adjusting the representation of a cat, we would use \"cat\" It is not\n", + " # a keyword necessarily but what the model understands the concept to represent.\n", + " # \"person\" will affect men, women, children, etc but will not affect cats, dogs, etc\n", + " # it is the models base general understanding of the concept and everything it represents\n", + " # you can leave it blank to affect everything. In this example, we are adjusting\n", + " # detail, so we will leave it blank to affect everything\n", + " {\n", + " 'target_class': \"\",\n", + " # positive is the prompt for the positive side of the slider.\n", + " # It is the concept that will be excited and amplified in the model when we slide the slider\n", + " # to the positive side and forgotten / inverted when we slide\n", + " # the slider to the negative side. It is generally best to include the target_class in\n", + " # the prompt. You want it to be the extreme of what you want to train on. For example,\n", + " # if you want to train on fat people, you would use \"an extremely fat, morbidly obese person\"\n", + " # as the prompt. 
Not just \"fat person\"\n", + " # max 75 tokens for now\n", + " 'positive': \"high detail, 8k, intricate, detailed, high resolution, high res, high quality\",\n", + " # negative is the prompt for the negative side of the slider and works the same as positive\n", + " # it does not necessarily work the same as a negative prompt when generating images\n", + " # these need to be polar opposites.\n", + " # max 76 tokens for now\n", + " 'negative': \"blurry, boring, fuzzy, low detail, low resolution, low res, low quality\",\n", + " # the loss for this target is multiplied by this number.\n", + " # if you are doing more than one target it may be good to set less important ones\n", + " # to a lower number like 0.1 so they don't outweigh the primary target\n", + " 'weight': 1.0,\n", + " },\n", + " ],\n", + " },\n", + " },\n", + " ]\n", + " },\n", + "\n", + " # You can put any information you want here, and it will be saved in the model.\n", + " # The below is an example, but you can put your grocery list in it if you want.\n", + " # It is saved in the model so be aware of that. The software will include this\n", + " # plus some other information for you automatically\n", + " 'meta': {\n", + " # [name] gets replaced with the name above\n", + " 'name': \"[name]\",\n", + " 'version': '1.0',\n", + " # 'creator': {\n", + " # 'name': 'your name',\n", + " # 'email': 'your@gmail.com',\n", + " # 'website': 'https://your.website'\n", + " # }\n", + " }\n", + "})\n" + ], + "metadata": { + "id": "_t28QURYjRQO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Run it\n", + "\n", + "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1 In the samples folder, there are preiodic sampled. This doesnt work great with colab. Ill update soon." 
+ ], + "metadata": { + "id": "h6F1FlM2Wb3l" + } + }, + { + "cell_type": "code", + "source": [ + "run_job(job_to_run)\n" + ], + "metadata": { + "id": "HkajwI8gteOh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Done\n", + "\n", + "Check your ourput dir and get your slider\n" + ], + "metadata": { + "id": "Hblgb5uwW5SD" + } + } + ] +} diff --git a/output/.gitkeep b/output/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ab621c1f2d4f545808723f8668b7c9276354c9d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +gradio +huggingface_hub +python-slugify +oyaml +modal +python-dotenv \ No newline at end of file diff --git a/requirements_local.txt b/requirements_local.txt new file mode 100644 index 0000000000000000000000000000000000000000..f45766b8d5b7eed9c0e0e9bcf9b0f5c0a3cc3588 --- /dev/null +++ b/requirements_local.txt @@ -0,0 +1,35 @@ +torch +torchvision +safetensors +git+https://github.com/huggingface/diffusers.git +transformers +lycoris-lora==1.8.3 +flatten_json +pyyaml +oyaml +tensorboard +kornia +invisible-watermark +einops +accelerate +toml +albumentations==1.4.15 +albucore==0.0.16 +pydantic +omegaconf +k-diffusion +open_clip_torch +timm +prodigyopt +controlnet_aux==0.0.7 +python-dotenv +bitsandbytes +hf_transfer +lpips +pytorch_fid +optimum-quanto==0.2.4 +sentencepiece +huggingface_hub +peft +gradio +python-slugify \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000000000000000000000000000000000000..6f13308117a84bc47a054c29d2088eec815572d8 --- /dev/null +++ b/run.py @@ -0,0 +1,90 @@ +import os +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import sys +from typing import Union, OrderedDict +from dotenv import load_dotenv +# Load the .env file if it exists +load_dotenv() + +sys.path.insert(0, os.getcwd()) +# must come before ANY torch or fastai imports +# import toolkit.cuda_malloc + +# turn off diffusers telemetry until I can figure out how to make it opt-in +os.environ['DISABLE_TELEMETRY'] = 'YES' + +# check if we have DEBUG_TOOLKIT in env +if os.environ.get("DEBUG_TOOLKIT", "0") == "1": + # set torch to trace mode + import torch + torch.autograd.set_detect_anomaly(True) +import argparse +from toolkit.job import get_job + + +def print_end_message(jobs_completed, jobs_failed): + failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" + completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" + + print("") + print("========================================") + print("Result:") + if len(completed_string) > 0: + print(f" - {completed_string}") + if len(failure_string) > 0: + print(f" - {failure_string}") + print("========================================") + + +def main(): + parser = argparse.ArgumentParser() + + # require at lease one config file + parser.add_argument( + 'config_file_list', + nargs='+', + type=str, + help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially' + ) + + # flag to continue if failed job + parser.add_argument( + '-r', '--recover', + action='store_true', + help='Continue running additional jobs even if a job fails' + ) + + # flag to continue if failed job + parser.add_argument( + 
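+        # note: this is the optional run name (not a recovery flag); it replaces the [name] tag in
+        # the config so a shared config file can be reused across runs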
'-n', '--name', + type=str, + default=None, + help='Name to replace [name] tag in config file, useful for shared config file' + ) + args = parser.parse_args() + + config_file_list = args.config_file_list + if len(config_file_list) == 0: + raise Exception("You must provide at least one config file") + + jobs_completed = 0 + jobs_failed = 0 + + print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") + + for config_file in config_file_list: + try: + job = get_job(config_file, args.name) + job.run() + job.cleanup() + jobs_completed += 1 + except Exception as e: + print(f"Error running job: {e}") + jobs_failed += 1 + if not args.recover: + print_end_message(jobs_completed, jobs_failed) + raise e + + +if __name__ == '__main__': + main() diff --git a/run_modal.py b/run_modal.py new file mode 100644 index 0000000000000000000000000000000000000000..4675c1cb8ec709126317dcba02315177df777f68 --- /dev/null +++ b/run_modal.py @@ -0,0 +1,175 @@ +''' + +ostris/ai-toolkit on https://modal.com +Run training with the following command: +modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml + +''' + +import os +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import sys +import modal +from dotenv import load_dotenv +# Load the .env file if it exists +load_dotenv() + +sys.path.insert(0, "/root/ai-toolkit") +# must come before ANY torch or fastai imports +# import toolkit.cuda_malloc + +# turn off diffusers telemetry until I can figure out how to make it opt-in +os.environ['DISABLE_TELEMETRY'] = 'YES' + +# define the volume for storing model outputs, using "creating volumes lazily": https://modal.com/docs/guide/volumes +# you will find your model, samples and optimizer stored in: https://modal.com/storage/your-username/main/flux-lora-models +model_volume = modal.Volume.from_name("flux-lora-models", create_if_missing=True) + +# modal_output, due to "cannot mount volume on non-empty path" requirement +MOUNT_DIR = "/root/ai-toolkit/modal_output" # modal_output, due to "cannot mount volume on non-empty path" requirement + +# define modal app +image = ( + modal.Image.debian_slim(python_version="3.11") + # install required system and pip packages, more about this modal approach: https://modal.com/docs/examples/dreambooth_app + .apt_install("libgl1", "libglib2.0-0") + .pip_install( + "python-dotenv", + "torch", + "diffusers[torch]", + "transformers", + "ftfy", + "torchvision", + "oyaml", + "opencv-python", + "albumentations", + "safetensors", + "lycoris-lora==1.8.3", + "flatten_json", + "pyyaml", + "tensorboard", + "kornia", + "invisible-watermark", + "einops", + "accelerate", + "toml", + "pydantic", + "omegaconf", + "k-diffusion", + "open_clip_torch", + "timm", + "prodigyopt", + "controlnet_aux==0.0.7", + "bitsandbytes", + "hf_transfer", + "lpips", + "pytorch_fid", + "optimum-quanto", + "sentencepiece", + "huggingface_hub", + "peft" + ) +) + +# mount for the entire ai-toolkit directory +# example: "/Users/username/ai-toolkit" is the local directory, "/root/ai-toolkit" is the remote directory +code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit") + +# create the Modal app with the necessary mounts and volumes +app = modal.App(name="flux-lora-training", image=image, mounts=[code_mount], volumes={MOUNT_DIR: model_volume}) + +# Check if we have DEBUG_TOOLKIT in env +if os.environ.get("DEBUG_TOOLKIT", "0") == "1": + # Set torch to trace mode + import torch + torch.autograd.set_detect_anomaly(True) + +import 
argparse +from toolkit.job import get_job + +def print_end_message(jobs_completed, jobs_failed): + failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" + completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" + + print("") + print("========================================") + print("Result:") + if len(completed_string) > 0: + print(f" - {completed_string}") + if len(failure_string) > 0: + print(f" - {failure_string}") + print("========================================") + + +@app.function( + # request a GPU with at least 24GB VRAM + # more about modal GPU's: https://modal.com/docs/guide/gpu + gpu="A100", # gpu="H100" + # more about modal timeouts: https://modal.com/docs/guide/timeouts + timeout=7200 # 2 hours, increase or decrease if needed +) +def main(config_file_list_str: str, recover: bool = False, name: str = None): + # convert the config file list from a string to a list + config_file_list = config_file_list_str.split(",") + + jobs_completed = 0 + jobs_failed = 0 + + print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") + + for config_file in config_file_list: + try: + job = get_job(config_file, name) + + job.config['process'][0]['training_folder'] = MOUNT_DIR + os.makedirs(MOUNT_DIR, exist_ok=True) + print(f"Training outputs will be saved to: {MOUNT_DIR}") + + # run the job + job.run() + + # commit the volume after training + model_volume.commit() + + job.cleanup() + jobs_completed += 1 + except Exception as e: + print(f"Error running job: {e}") + jobs_failed += 1 + if not recover: + print_end_message(jobs_completed, jobs_failed) + raise e + + print_end_message(jobs_completed, jobs_failed) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # require at least one config file + parser.add_argument( + 'config_file_list', + nargs='+', + type=str, + help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially' + ) + + # flag to continue if a job fails + parser.add_argument( + '-r', '--recover', + action='store_true', + help='Continue running additional jobs even if a job fails' + ) + + # optional name replacement for config file + parser.add_argument( + '-n', '--name', + type=str, + default=None, + help='Name to replace [name] tag in config file, useful for shared config file' + ) + args = parser.parse_args() + + # convert list of config files to a comma-separated string for Modal compatibility + config_file_list_str = ",".join(args.config_file_list) + + main.call(config_file_list_str=config_file_list_str, recover=args.recover, name=args.name) diff --git a/run_modal_from_hf.py b/run_modal_from_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..f839e5c23f1520024eaa2b46b44bf9a47537dc53 --- /dev/null +++ b/run_modal_from_hf.py @@ -0,0 +1,231 @@ +''' +ostris/ai-toolkit on https://modal.com +This module provides the Modal app and main function for training FLUX LoRA models. +The main() function is meant to be called from hf_ui.py, not run directly. 
+''' + +import os +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import sys +import modal +from dotenv import load_dotenv +# Load the .env file if it exists +load_dotenv() +import yaml +import traceback +import zipfile + +sys.path.insert(0, "/root/ai-toolkit") +# must come before ANY torch or fastai imports +# import toolkit.cuda_malloc + +# turn off diffusers telemetry until I can figure out how to make it opt-in +os.environ['DISABLE_TELEMETRY'] = 'YES' +# declare the Modal secrets +hf_secret = modal.Secret.from_name("huggingface-secret") +wandb_secret = modal.Secret.from_name("wandb-secret") + +# define the volume for storing model outputs, using "creating volumes lazily": https://modal.com/docs/guide/volumes +# you will find your model, samples and optimizer stored in: https://modal.com/storage/your-username/main/flux-lora-models +model_volume = modal.Volume.from_name("flux-lora-models", create_if_missing=True) + +# modal_output, due to "cannot mount volume on non-empty path" requirement +MOUNT_DIR = "/root/ai-toolkit/modal_output" # modal_output, due to "cannot mount volume on non-empty path" requirement + +# define modal app +image = ( + modal.Image.debian_slim(python_version="3.11") + # install required system and pip packages, more about this modal approach: https://modal.com/docs/examples/dreambooth_app + .apt_install("libgl1", "libglib2.0-0") + .pip_install( + "python-dotenv", + "torch", + "diffusers[torch]", + "transformers", + "ftfy", + "torchvision", + "oyaml", + "opencv-python", + "albumentations", + "safetensors", + "lycoris-lora==1.8.3", + "flatten_json", + "pyyaml", + "tensorboard", + "kornia", + "invisible-watermark", + "einops", + "accelerate", + "toml", + "pydantic", + "omegaconf", + "k-diffusion", + "open_clip_torch", + "timm", + "prodigyopt", + "controlnet_aux==0.0.7", + "bitsandbytes", + "hf_transfer", + "lpips", + "pytorch_fid", + "optimum-quanto", + "sentencepiece", + "huggingface_hub", + "peft", + "wandb", + ) +) + +# mount from the root directory of the HF Space +code_mount = modal.Mount.from_local_dir( + local_dir="/home/user/app", # default path inside the HF Space + remote_path="/root/ai-toolkit" +) + +# create the Modal app with the necessary mounts and volumes +app = modal.App(name="flux-lora-training", image=image, mounts=[code_mount], volumes={MOUNT_DIR: model_volume}) + +# Check if we have DEBUG_TOOLKIT in env +if os.environ.get("DEBUG_TOOLKIT", "0") == "1": + # Set torch to trace mode + import torch + torch.autograd.set_detect_anomaly(True) +
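# Illustrative only, not used by the code below: a minimal sketch of the YAML
# string that main() expects, covering just the keys that main() validates or
# reads directly. The process `type` value and folder names are placeholder
# assumptions; a real training config needs more keys (see config/examples in
# this repo), and every path must live under /root/ai-toolkit.
_EXAMPLE_CONFIG_YAML = """
config:
  name: my_first_flux_lora
  process:
    - type: sd_trainer
      training_folder: /root/ai-toolkit/modal_output
      logging:
        use_wandb: false
      datasets:
        - folder_path: /root/ai-toolkit/dataset
"""
# A caller such as hf_ui.py would hand a string like this to the Modal function,
# roughly the way run_modal.py invokes its own main():
#   main.call(config_file_list_str=_EXAMPLE_CONFIG_YAML, name="my_first_flux_lora")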
+import argparse +from toolkit.job import get_job +from toolkit.logging import WandbLogger + +def print_end_message(jobs_completed, jobs_failed): + failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" + completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" + + print("") + print("========================================") + print("Result:") + if len(completed_string) > 0: + print(f" - {completed_string}") + if len(failure_string) > 0: + print(f" - {failure_string}") + print("========================================") + + +@app.function( + # request a GPU with at least 24GB VRAM + # more about modal GPUs: https://modal.com/docs/guide/gpu + gpu="A100", # gpu="H100" + # more about modal timeouts: https://modal.com/docs/guide/timeouts + timeout=7200, # 2 hours, increase or decrease if needed + secrets=[hf_secret, wandb_secret] +) +def main(config_file_list_str: str, recover: bool = False, name: str = None): + # the secrets are automatically injected into environment variables + # os.environ["HF_TOKEN"] and os.environ["WANDB_API_KEY"] + + # convert the config file list from a string to a list + # config_file_list = config_file_list_str.split(",") + # convert the config string into a usable dict + config = None + try: + config = yaml.safe_load(config_file_list_str) + except Exception as e: + print(f"Error loading config file: {e}") + traceback.print_exc() + raise e + + jobs_completed = 0 + jobs_failed = 0 + + print(f"Running {config['config']['name']}") + + try: + # 1. validate config file to make sure required keys are present + if 'config' not in config: + raise ValueError("config file must have a `config` section") + if 'process' not in config['config']: + raise ValueError("config file must have a `process` section") + if len(config['config']['process']) == 0: + raise ValueError("config file must have at least one process") + if 'type' not in config['config']['process'][0]: + raise ValueError("config file process must have a `type`") + if 'training_folder' not in config['config']['process'][0]: + raise ValueError("config file process must have a `training_folder`") + if not config['config']['process'][0]['training_folder'].startswith("/root/ai-toolkit"): + raise ValueError("config file process training_folder path must start with /root/ai-toolkit") + + # find a dataset inside process object + datasets = config['config']['process'][0].get('datasets', None) + if datasets is not None and isinstance(datasets, list): + for dataset in datasets: + if 'folder_path' in dataset: + if not dataset['folder_path'].startswith('/root/ai-toolkit'): + raise ValueError("config file process dataset folder_path must start with /root/ai-toolkit") + + job = get_job(config, name) + + job.config['process'][0]['training_folder'] = MOUNT_DIR + os.makedirs(MOUNT_DIR, exist_ok=True) + print(f"Training outputs will be saved to: {MOUNT_DIR}") + + # setup wandb + if config['config']['process'][0]['logging']['use_wandb']: + wandb_token = os.environ.get('WANDB_API_KEY', None) + if wandb_token: + wandb_logger = WandbLogger( + project="flux-lora-training", + run_name=name, + config=job.raw_config, + ) + job.meta["wandb"] = wandb_logger.run.id + job.process[0].logger = wandb_logger + else: + print("WandB token not found, skipping WandB logging") + config['config']['process'][0]['logging']['use_wandb'] = False # disable if no key was given + + # handle dataset zip + datasets = config['config']['process'][0].get('datasets', None) + if datasets is not None and isinstance(datasets, list): + for dataset in datasets: + dataset_path = dataset.get('folder_path', None) + if dataset_path is not None: + # check whether the folder contains a zip file + for file in os.listdir(dataset_path): + if file.lower().endswith('.zip'): + zip_path = os.path.join(dataset_path, file) + # create a subfolder to extract into + extract_path = os.path.join(dataset_path, 'extracted') + os.makedirs(extract_path, exist_ok=True) + + print(f"Extracting dataset zip file: {zip_path}") + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(extract_path) + + # update the dataset path in the config + dataset['folder_path'] = extract_path + # delete the zip file after extraction + os.remove(zip_path) + print(f"Dataset extracted to: {extract_path}") + break + + # run the job + job.run() + + if config['config']['process'][0]['logging']['use_wandb']: + wandb_logger.finish() + + # commit the volume after training + model_volume.commit() + + job.cleanup() + jobs_completed += 1 + + except Exception as e: + print(f"Error 
running job: {e}") + if 'response' in e.__dict__: + print(f" - Response code: {e.response.status_code} text: {e.response.text}") + jobs_failed += 1 + traceback.print_exc() + if not recover: + print_end_message(jobs_completed, jobs_failed) + raise e + + print_end_message(jobs_completed, jobs_failed) \ No newline at end of file diff --git a/scripts/convert_cog.py b/scripts/convert_cog.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4f6e73c1d3e444583319b37557ad36ec988ccf --- /dev/null +++ b/scripts/convert_cog.py @@ -0,0 +1,128 @@ +import json +from collections import OrderedDict +import os +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +device = torch.device('cpu') + +# [diffusers] -> kohya +embedding_mapping = { + 'text_encoders_0': 'clip_l', + 'text_encoders_1': 'clip_g' +} + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps') +sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json') + +# load keymap +with open(sdxl_keymap_path, 'r') as f: + ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap'] + +# invert the item / key pairs +diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()} + + +def get_ldm_key(diffuser_key): + diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}" + diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight') + diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight') + diffuser_key = diffuser_key.replace('_alpha', '.alpha') + diffuser_key = diffuser_key.replace('_processor_to_', '_to_') + diffuser_key = diffuser_key.replace('_to_out.', '_to_out_0.') + if diffuser_key in diffusers_ldm_keymap: + return diffusers_ldm_keymap[diffuser_key] + else: + raise KeyError(f"Key {diffuser_key} not found in keymap") + + +def convert_cog(lora_path, embedding_path): + embedding_state_dict = OrderedDict() + lora_state_dict = OrderedDict() + + # # normal dict + # normal_dict = OrderedDict() + # example_path = "/mnt/Models/stable-diffusion/models/LoRA/sdxl/LogoRedmond_LogoRedAF.safetensors" + # with safe_open(example_path, framework="pt", device='cpu') as f: + # keys = list(f.keys()) + # for key in keys: + # normal_dict[key] = f.get_tensor(key) + + with safe_open(embedding_path, framework="pt", device='cpu') as f: + keys = list(f.keys()) + for key in keys: + new_key = embedding_mapping[key] + embedding_state_dict[new_key] = f.get_tensor(key) + + with safe_open(lora_path, framework="pt", device='cpu') as f: + keys = list(f.keys()) + lora_rank = None + + # get the lora dim first. 
Check first 3 linear layers just to be safe + for key in keys: + new_key = get_ldm_key(key) + tensor = f.get_tensor(key) + num_checked = 0 + if len(tensor.shape) == 2: + this_dim = min(tensor.shape) + if lora_rank is None: + lora_rank = this_dim + elif lora_rank != this_dim: + raise ValueError(f"lora rank is not consistent, got {tensor.shape}") + else: + num_checked += 1 + if num_checked >= 3: + break + + for key in keys: + new_key = get_ldm_key(key) + tensor = f.get_tensor(key) + if new_key.endswith('.lora_down.weight'): + alpha_key = new_key.replace('.lora_down.weight', '.alpha') + # diffusers does not have alpha, they usa an alpha multiplier of 1 which is a tensor weight of the dims + # assume first smallest dim is the lora rank if shape is 2 + lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank + + lora_state_dict[new_key] = tensor + + return lora_state_dict, embedding_state_dict + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + 'lora_path', + type=str, + help='Path to lora file' + ) + parser.add_argument( + 'embedding_path', + type=str, + help='Path to embedding file' + ) + + parser.add_argument( + '--lora_output', + type=str, + default="lora_output", + ) + + parser.add_argument( + '--embedding_output', + type=str, + default="embedding_output", + ) + + args = parser.parse_args() + + lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path) + + # save them + save_file(lora_state_dict, args.lora_output) + save_file(embedding_state_dict, args.embedding_output) + print(f"Saved lora to {args.lora_output}") + print(f"Saved embedding to {args.embedding_output}") diff --git a/scripts/convert_lora_to_peft_format.py b/scripts/convert_lora_to_peft_format.py new file mode 100644 index 0000000000000000000000000000000000000000..3034db646ce0cbf784940df17a45e2468063f485 --- /dev/null +++ b/scripts/convert_lora_to_peft_format.py @@ -0,0 +1,91 @@ +# currently only works with flux as support is not quite there yet + +import argparse +import os.path +from collections import OrderedDict + +parser = argparse.ArgumentParser() +parser.add_argument( + 'input_path', + type=str, + help='Path to original sdxl model' +) +parser.add_argument( + 'output_path', + type=str, + help='output path' +) +args = parser.parse_args() +args.input_path = os.path.abspath(args.input_path) +args.output_path = os.path.abspath(args.output_path) + +from safetensors.torch import load_file, save_file + +meta = OrderedDict() +meta['format'] = 'pt' + +state_dict = load_file(args.input_path) + +# peft doesnt have an alpha so we need to scale the weights +alpha_keys = [ + 'lora_transformer_single_transformer_blocks_0_attn_to_q.alpha' # flux +] + +# keys where the rank is in the first dimension +rank_idx0_keys = [ + 'lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight' + # 'transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight' +] + +alpha = None +rank = None + +for key in rank_idx0_keys: + if key in state_dict: + rank = int(state_dict[key].shape[0]) + break + +if rank is None: + raise ValueError(f'Could not find rank in state dict') + +for key in alpha_keys: + if key in state_dict: + alpha = int(state_dict[key]) + break + +if alpha is None: + # set to rank if not found + alpha = rank + + +up_multiplier = alpha / rank + +new_state_dict = {} + +for key, value in state_dict.items(): + if key.endswith('.alpha'): + continue + + orig_dtype = value.dtype + + new_val = value.float() * 
up_multiplier + + new_key = key + new_key = new_key.replace('lora_transformer_', 'transformer.') + for i in range(100): + new_key = new_key.replace(f'transformer_blocks_{i}_', f'transformer_blocks.{i}.') + new_key = new_key.replace('lora_down', 'lora_A') + new_key = new_key.replace('lora_up', 'lora_B') + new_key = new_key.replace('_lora', '.lora') + new_key = new_key.replace('attn_', 'attn.') + new_key = new_key.replace('ff_', 'ff.') + new_key = new_key.replace('context_net_', 'context.net.') + new_key = new_key.replace('0_proj', '0.proj') + new_key = new_key.replace('norm_linear', 'norm.linear') + new_key = new_key.replace('norm_out_linear', 'norm_out.linear') + new_key = new_key.replace('to_out_', 'to_out.') + + new_state_dict[new_key] = new_val.to(orig_dtype) + +save_file(new_state_dict, args.output_path, meta) +print(f'Saved to {args.output_path}') diff --git a/scripts/generate_sampler_step_scales.py b/scripts/generate_sampler_step_scales.py new file mode 100644 index 0000000000000000000000000000000000000000..11efb3183becb48ec4a485565d53049fb6a8d11c --- /dev/null +++ b/scripts/generate_sampler_step_scales.py @@ -0,0 +1,20 @@ +import argparse +import torch +import os +from diffusers import StableDiffusionPipeline +import sys + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +# add project root to path +sys.path.append(PROJECT_ROOT) + +SAMPLER_SCALES_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'samplers_scales') + + +parser = argparse.ArgumentParser(description='Process some images.') +add_arg = parser.add_argument +add_arg('--model', type=str, required=True, help='Path to model') +add_arg('--sampler', type=str, required=True, help='Name of sampler') + +args = parser.parse_args() + diff --git a/scripts/make_diffusers_model.py b/scripts/make_diffusers_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4536a9215540dd01321ef1426665db9d6ef6347f --- /dev/null +++ b/scripts/make_diffusers_model.py @@ -0,0 +1,61 @@ +import argparse +from collections import OrderedDict +import sys +import os +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(ROOT_DIR) + +import torch + +from toolkit.config_modules import ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion + + +parser = argparse.ArgumentParser() +parser.add_argument( + 'input_path', + type=str, + help='Path to original sdxl model' +) +parser.add_argument( + 'output_path', + type=str, + help='output path' +) +parser.add_argument('--sdxl', action='store_true', help='is sdxl model') +parser.add_argument('--refiner', action='store_true', help='is refiner model') +parser.add_argument('--ssd', action='store_true', help='is ssd model') +parser.add_argument('--sd2', action='store_true', help='is sd 2 model') + +args = parser.parse_args() +device = torch.device('cpu') +dtype = torch.float32 + +print(f"Loading model from {args.input_path}") + + +diffusers_model_config = ModelConfig( + name_or_path=args.input_path, + is_xl=args.sdxl, + is_v2=args.sd2, + is_ssd=args.ssd, + dtype=dtype, + ) +diffusers_sd = StableDiffusion( + model_config=diffusers_model_config, + device=device, + dtype=dtype, +) +diffusers_sd.load_model() + + +print(f"Loaded model from {args.input_path}") + +diffusers_sd.pipeline.fuse_lora() + +meta = OrderedDict() + +diffusers_sd.save(args.output_path, meta=meta) + + +print(f"Saved to {args.output_path}") diff --git a/scripts/make_lcm_sdxl_model.py b/scripts/make_lcm_sdxl_model.py new file mode 100644 index 
0000000000000000000000000000000000000000..20e95ce795a39fe2837b80fcbf1950c256ad4c59 --- /dev/null +++ b/scripts/make_lcm_sdxl_model.py @@ -0,0 +1,67 @@ +import argparse +from collections import OrderedDict + +import torch + +from toolkit.config_modules import ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion + + +parser = argparse.ArgumentParser() +parser.add_argument( + 'input_path', + type=str, + help='Path to original sdxl model' +) +parser.add_argument( + 'output_path', + type=str, + help='output path' +) +parser.add_argument('--sdxl', action='store_true', help='is sdxl model') +parser.add_argument('--refiner', action='store_true', help='is refiner model') +parser.add_argument('--ssd', action='store_true', help='is ssd model') +parser.add_argument('--sd2', action='store_true', help='is sd 2 model') + +args = parser.parse_args() +device = torch.device('cpu') +dtype = torch.float32 + +print(f"Loading model from {args.input_path}") + +if args.sdxl: + adapter_id = "latent-consistency/lcm-lora-sdxl" +if args.refiner: + adapter_id = "latent-consistency/lcm-lora-sdxl" +elif args.ssd: + adapter_id = "latent-consistency/lcm-lora-ssd-1b" +else: + adapter_id = "latent-consistency/lcm-lora-sdv1-5" + + +diffusers_model_config = ModelConfig( + name_or_path=args.input_path, + is_xl=args.sdxl, + is_v2=args.sd2, + is_ssd=args.ssd, + dtype=dtype, + ) +diffusers_sd = StableDiffusion( + model_config=diffusers_model_config, + device=device, + dtype=dtype, +) +diffusers_sd.load_model() + + +print(f"Loaded model from {args.input_path}") + +diffusers_sd.pipeline.load_lora_weights(adapter_id) +diffusers_sd.pipeline.fuse_lora() + +meta = OrderedDict() + +diffusers_sd.save(args.output_path, meta=meta) + + +print(f"Saved to {args.output_path}") diff --git a/scripts/patch_te_adapter.py b/scripts/patch_te_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7249a46d8e566c3889538c465359e6c66b1c9602 --- /dev/null +++ b/scripts/patch_te_adapter.py @@ -0,0 +1,42 @@ +import torch +from safetensors.torch import save_file, load_file +from collections import OrderedDict +meta = OrderedDict() +meta["format"] ="pt" + +attn_dict = load_file("/mnt/Train/out/ip_adapter/sd15_bigG/sd15_bigG_000266000.safetensors") +state_dict = load_file("/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors") + +attn_list = [] +for key, value in state_dict.items(): + if "attn1" in key: + attn_list.append(key) + +attn_names = ['down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor'] + 
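# Worked example of the key surgery performed below, assuming index 0 exists in
# the adapter checkpoint (the real indices depend on what was saved there):
#   adapter_names entry : te_adapter.adapter_modules.0.adapter
#   adapter_k_name      : te_adapter.adapter_modules.0.to_k_adapter.weight
#   adapter_v_name      : te_adapter.adapter_modules.0.to_v_adapter.weight
#   state_k_name        : down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight
#   state_v_name        : down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight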
+adapter_names = [] +for i in range(100): + if f'te_adapter.adapter_modules.{i}.to_k_adapter.weight' in attn_dict: + adapter_names.append(f"te_adapter.adapter_modules.{i}.adapter") + + +for i in range(len(adapter_names)): + adapter_name = adapter_names[i] + attn_name = attn_names[i] + adapter_k_name = adapter_name[:-8] + '.to_k_adapter.weight' + adapter_v_name = adapter_name[:-8] + '.to_v_adapter.weight' + state_k_name = attn_name.replace(".processor", ".to_k.weight") + state_v_name = attn_name.replace(".processor", ".to_v.weight") + if adapter_k_name in attn_dict: + state_dict[state_k_name] = attn_dict[adapter_k_name] + state_dict[state_v_name] = attn_dict[adapter_v_name] + else: + print("adapter_k_name", adapter_k_name) + print("state_k_name", state_k_name) + +for key, value in state_dict.items(): + state_dict[key] = value.cpu().to(torch.float16) + +save_file(state_dict, "/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors", metadata=meta) + +print("Done") diff --git a/scripts/repair_dataset_folder.py b/scripts/repair_dataset_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9d277508c19046b5737620a01b9eba09635e98 --- /dev/null +++ b/scripts/repair_dataset_folder.py @@ -0,0 +1,65 @@ +import argparse +from PIL import Image +from PIL.ImageOps import exif_transpose +from tqdm import tqdm +import os + +parser = argparse.ArgumentParser(description='Process some images.') +parser.add_argument("input_folder", type=str, help="Path to folder containing images") + +args = parser.parse_args() + +img_types = ['.jpg', '.jpeg', '.png', '.webp'] + +# find all images in the input folder +images = [] +for root, _, files in os.walk(args.input_folder): + for file in files: + if file.lower().endswith(tuple(img_types)): + images.append(os.path.join(root, file)) +print(f"Found {len(images)} images") + +num_skipped = 0 +num_repaired = 0 +num_deleted = 0 + +pbar = tqdm(total=len(images), desc=f"Repaired {num_repaired} images", unit="image") +for img_path in images: + filename = os.path.basename(img_path) + filename_no_ext, file_extension = os.path.splitext(filename) + # if it is jpg, ignore + if file_extension.lower() == '.jpg': + num_skipped += 1 + pbar.update(1) + + continue + + try: + img = Image.open(img_path) + except Exception as e: + print(f"Error opening {img_path}: {e}") + # delete it + os.remove(img_path) + num_deleted += 1 + pbar.update(1) + pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}") + continue + + + try: + img = exif_transpose(img) + except Exception as e: + print(f"Error rotating {img_path}: {e}") + + new_path = os.path.join(os.path.dirname(img_path), filename_no_ext + '.jpg') + + img = img.convert("RGB") + img.save(new_path, quality=95) + # remove the old file + os.remove(img_path) + num_repaired += 1 + pbar.update(1) + # update pbar + pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}") + +print("Done") \ No newline at end of file diff --git a/testing/compare_keys.py b/testing/compare_keys.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4f95203fe1024daeb66ddd79696875f04578c7 --- /dev/null +++ b/testing/compare_keys.py @@ -0,0 +1,99 @@ +import argparse +import os + +import torch +from diffusers.loaders import LoraLoaderMixin +from safetensors.torch import load_file +from collections import OrderedDict +import json +# this was just used to match the vae keys to the diffusers keys +# you probably wont need 
this. Unless they change them.... again... again +# on second thought, you probably will + +device = torch.device('cpu') +dtype = torch.float32 + +parser = argparse.ArgumentParser() + +# require at lease one config file +parser.add_argument( + 'file_1', + nargs='+', + type=str, + help='Path to first safe tensor file' +) + +parser.add_argument( + 'file_2', + nargs='+', + type=str, + help='Path to second safe tensor file' +) + +args = parser.parse_args() + +find_matches = False + +state_dict_file_1 = load_file(args.file_1[0]) +state_dict_1_keys = list(state_dict_file_1.keys()) + +state_dict_file_2 = load_file(args.file_2[0]) +state_dict_2_keys = list(state_dict_file_2.keys()) +keys_in_both = [] + +keys_not_in_state_dict_2 = [] +for key in state_dict_1_keys: + if key not in state_dict_2_keys: + keys_not_in_state_dict_2.append(key) + +keys_not_in_state_dict_1 = [] +for key in state_dict_2_keys: + if key not in state_dict_1_keys: + keys_not_in_state_dict_1.append(key) + +keys_in_both = [] +for key in state_dict_1_keys: + if key in state_dict_2_keys: + keys_in_both.append(key) + +# sort them +keys_not_in_state_dict_2.sort() +keys_not_in_state_dict_1.sort() +keys_in_both.sort() + + +json_data = { + "both": keys_in_both, + "not_in_state_dict_2": keys_not_in_state_dict_2, + "not_in_state_dict_1": keys_not_in_state_dict_1 +} +json_data = json.dumps(json_data, indent=4) + +remaining_diffusers_values = OrderedDict() +for key in keys_not_in_state_dict_1: + remaining_diffusers_values[key] = state_dict_file_2[key] + +# print(remaining_diffusers_values.keys()) + +remaining_ldm_values = OrderedDict() +for key in keys_not_in_state_dict_2: + remaining_ldm_values[key] = state_dict_file_1[key] + +# print(json_data) + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +json_save_path = os.path.join(project_root, 'config', 'keys.json') +json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') +json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') +state_dict_1_filename = os.path.basename(args.file_1[0]) +state_dict_2_filename = os.path.basename(args.file_2[0]) +# save key names for each in own file +with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f: + f.write(json.dumps(state_dict_1_keys, indent=4)) + +with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f: + f.write(json.dumps(state_dict_2_keys, indent=4)) + + +with open(json_save_path, 'w') as f: + f.write(json_data) \ No newline at end of file diff --git a/testing/generate_lora_mapping.py b/testing/generate_lora_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..e632d2a662f6a6498a8d340074ea9c9a27ac431a --- /dev/null +++ b/testing/generate_lora_mapping.py @@ -0,0 +1,130 @@ +from collections import OrderedDict + +import torch +from safetensors.torch import load_file +import argparse +import os +import json + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json') + +# load keymap +with open(keymap_path, 'r') as f: + keymap = json.load(f) + +lora_keymap = OrderedDict() + +# convert keymap to lora key naming +for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items(): + if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'): + # skip it + continue + # sdxl has same te for locon with kohya and ours + if ldm_key.startswith('conditioner'): + #skip it + continue + # 
ignore vae + if ldm_key.startswith('first_stage_model'): + continue + ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_') + ldm_key = ldm_key.replace('.weight', '') + ldm_key = ldm_key.replace('.', '_') + + diffusers_key = diffusers_key.replace('unet_', 'lora_unet_') + diffusers_key = diffusers_key.replace('.weight', '') + diffusers_key = diffusers_key.replace('.', '_') + + lora_keymap[f"{ldm_key}.alpha"] = f"{diffusers_key}.alpha" + lora_keymap[f"{ldm_key}.lora_down.weight"] = f"{diffusers_key}.lora_down.weight" + lora_keymap[f"{ldm_key}.lora_up.weight"] = f"{diffusers_key}.lora_up.weight" + + +parser = argparse.ArgumentParser() +parser.add_argument("input", help="input file") +parser.add_argument("input2", help="input2 file") + +args = parser.parse_args() + +# name = args.name +# if args.sdxl: +# name += '_sdxl' +# elif args.sd2: +# name += '_sd2' +# else: +# name += '_sd1' +name = 'stable_diffusion_locon_sdxl' + +locon_save = load_file(args.input) +our_save = load_file(args.input2) + +our_extra_keys = list(set(our_save.keys()) - set(locon_save.keys())) +locon_extra_keys = list(set(locon_save.keys()) - set(our_save.keys())) + +print(f"we have {len(our_extra_keys)} extra keys") +print(f"locon has {len(locon_extra_keys)} extra keys") + +save_dtype = torch.float16 +print(f"our extra keys: {our_extra_keys}") +print(f"locon extra keys: {locon_extra_keys}") + + +def export_state_dict(our_save): + converted_state_dict = OrderedDict() + for key, value in our_save.items(): + # test encoders share keys for some reason + if key.startswith('lora_te'): + converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype) + else: + converted_key = key + for ldm_key, diffusers_key in lora_keymap.items(): + if converted_key == diffusers_key: + converted_key = ldm_key + + converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype) + return converted_state_dict + +def import_state_dict(loaded_state_dict): + converted_state_dict = OrderedDict() + for key, value in loaded_state_dict.items(): + if key.startswith('lora_te'): + converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype) + else: + converted_key = key + for ldm_key, diffusers_key in lora_keymap.items(): + if converted_key == ldm_key: + converted_key = diffusers_key + + converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype) + return converted_state_dict + + +# check it again +converted_state_dict = export_state_dict(our_save) +converted_extra_keys = list(set(converted_state_dict.keys()) - set(locon_save.keys())) +locon_extra_keys = list(set(locon_save.keys()) - set(converted_state_dict.keys())) + + +print(f"we have {len(converted_extra_keys)} extra keys") +print(f"locon has {len(locon_extra_keys)} extra keys") + +print(f"our extra keys: {converted_extra_keys}") + +# convert back +cycle_state_dict = import_state_dict(converted_state_dict) +cycle_extra_keys = list(set(cycle_state_dict.keys()) - set(our_save.keys())) +our_extra_keys = list(set(our_save.keys()) - set(cycle_state_dict.keys())) + +print(f"we have {len(our_extra_keys)} extra keys") +print(f"cycle has {len(cycle_extra_keys)} extra keys") + +# save keymap +to_save = OrderedDict() +to_save['ldm_diffusers_keymap'] = lora_keymap + +with open(os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', f'{name}.json'), 'w') as f: + json.dump(to_save, f, indent=4) + + + diff --git a/testing/generate_weight_mappings.py b/testing/generate_weight_mappings.py new file mode 100644 index 
0000000000000000000000000000000000000000..346fe09d5c98a22a3c06ac9ae1dadb549a196193 --- /dev/null +++ b/testing/generate_weight_mappings.py @@ -0,0 +1,479 @@ +import argparse +import gc +import os +import re +import os +# add project root to sys path +import sys + +from diffusers import DiffusionPipeline, StableDiffusionXLPipeline + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +from diffusers.loaders import LoraLoaderMixin +from safetensors.torch import load_file, save_file +from collections import OrderedDict +import json +from tqdm import tqdm + +from toolkit.config_modules import ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion + +KEYMAPS_FOLDER = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'toolkit', 'keymaps') + +device = torch.device('cpu') +dtype = torch.float32 + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +def get_reduced_shape(shape_tuple): + # iterate though shape anr remove 1s + new_shape = [] + for dim in shape_tuple: + if dim != 1: + new_shape.append(dim) + return tuple(new_shape) + + +parser = argparse.ArgumentParser() + +# require at lease one config file +parser.add_argument( + 'file_1', + nargs='+', + type=str, + help='Path to first safe tensor file' +) + +parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make') +parser.add_argument('--sdxl', action='store_true', help='is sdxl model') +parser.add_argument('--refiner', action='store_true', help='is refiner model') +parser.add_argument('--ssd', action='store_true', help='is ssd model') +parser.add_argument('--vega', action='store_true', help='is vega model') +parser.add_argument('--sd2', action='store_true', help='is sd 2 model') + +args = parser.parse_args() + +file_path = args.file_1[0] + +find_matches = False + +print(f'Loading diffusers model') + +ignore_ldm_begins_with = [] + +diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1] +if args.ssd: + diffusers_file_path = "segmind/SSD-1B" +if args.vega: + diffusers_file_path = "segmind/Segmind-Vega" + +# if args.refiner: +# diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0" + +if not args.refiner: + + diffusers_model_config = ModelConfig( + name_or_path=diffusers_file_path, + is_xl=args.sdxl, + is_v2=args.sd2, + is_ssd=args.ssd, + is_vega=args.vega, + dtype=dtype, + ) + diffusers_sd = StableDiffusion( + model_config=diffusers_model_config, + device=device, + dtype=dtype, + ) + diffusers_sd.load_model() + # delete things we dont need + del diffusers_sd.tokenizer + flush() + + print(f'Loading ldm model') + diffusers_state_dict = diffusers_sd.state_dict() +else: + # refiner wont work directly with stable diffusion + # so we need to load the model and then load the state dict + diffusers_pipeline = StableDiffusionXLPipeline.from_single_file( + diffusers_file_path, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", + ).to(device) + # diffusers_pipeline = StableDiffusionXLPipeline.from_single_file( + # file_path, + # torch_dtype=torch.float16, + # use_safetensors=True, + # variant="fp16", + # ).to(device) + + SD_PREFIX_VAE = "vae" + SD_PREFIX_UNET = "unet" + SD_PREFIX_REFINER_UNET = "refiner_unet" + SD_PREFIX_TEXT_ENCODER = "te" + + SD_PREFIX_TEXT_ENCODER1 = "te0" + SD_PREFIX_TEXT_ENCODER2 = "te1" + + diffusers_state_dict = OrderedDict() + for k, v in diffusers_pipeline.vae.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else 
f"{SD_PREFIX_VAE}_{k}" + diffusers_state_dict[new_key] = v + for k, v in diffusers_pipeline.text_encoder_2.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}" + diffusers_state_dict[new_key] = v + for k, v in diffusers_pipeline.unet.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + diffusers_state_dict[new_key] = v + + # add ignore ones as we are only going to focus on unet and copy the rest + # ignore_ldm_begins_with = ["conditioner.", "first_stage_model."] + +diffusers_dict_keys = list(diffusers_state_dict.keys()) + +ldm_state_dict = load_file(file_path) +ldm_dict_keys = list(ldm_state_dict.keys()) + +ldm_diffusers_keymap = OrderedDict() +ldm_diffusers_shape_map = OrderedDict() +ldm_operator_map = OrderedDict() +diffusers_operator_map = OrderedDict() + +total_keys = len(ldm_dict_keys) + +matched_ldm_keys = [] +matched_diffusers_keys = [] + +error_margin = 1e-8 + +tmp_merge_key = "TMP___MERGE" + +te_suffix = '' +proj_pattern_weight = None +proj_pattern_bias = None +text_proj_layer = None +if args.sdxl or args.ssd or args.vega: + te_suffix = '1' + ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks" + proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight" + proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias" + text_proj_layer = "conditioner.embedders.1.model.text_projection" +if args.refiner: + te_suffix = '1' + ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks" + proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight" + proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias" + text_proj_layer = "conditioner.embedders.0.model.text_projection" +if args.sd2: + te_suffix = '' + ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks" + proj_pattern_weight = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight" + proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias" + text_proj_layer = "cond_stage_model.model.text_projection" + +if args.sdxl or args.sd2 or args.ssd or args.refiner or args.vega: + if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys: + # d_model = int(checkpoint[prefix + "text_projection"].shape[0])) + d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0]) + elif "conditioner.embedders.1.model.text_projection.weight" in ldm_dict_keys: + # d_model = int(checkpoint[prefix + "text_projection"].shape[0])) + d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection.weight"].shape[0]) + elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys: + # d_model = int(checkpoint[prefix + "text_projection"].shape[0])) + d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0]) + else: + d_model = 1024 + + # do pre known merging + for ldm_key in ldm_dict_keys: + try: + match = re.match(proj_pattern_weight, ldm_key) + if match: + if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": + print("here") + number = int(match.group(1)) + new_val = torch.cat([ + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"], + 
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"], + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"], + ], dim=0) + # add to matched so we dont check them + matched_diffusers_keys.append( + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight") + matched_diffusers_keys.append( + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight") + matched_diffusers_keys.append( + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight") + # make diffusers convertable_dict + diffusers_state_dict[ + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.weight"] = new_val + + # add operator + ldm_operator_map[ldm_key] = { + "cat": [ + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight", + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight", + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight", + ], + } + + matched_ldm_keys.append(ldm_key) + + # text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + # text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :] + # text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :] + + # add diffusers operators + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight", + f"0:{d_model}, :" + ] + } + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight", + f"{d_model}:{d_model * 2}, :" + ] + } + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight", + f"{d_model * 2}:, :" + ] + } + + match = re.match(proj_pattern_bias, ldm_key) + if match: + number = int(match.group(1)) + new_val = torch.cat([ + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"], + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"], + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"], + ], dim=0) + # add to matched so we dont check them + matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias") + matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias") + matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias") + # make diffusers convertable_dict + diffusers_state_dict[ + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.bias"] = new_val + + # add operator + ldm_operator_map[ldm_key] = { + "cat": [ + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias", + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias", + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias", + ], + } + + matched_ldm_keys.append(ldm_key) + + # add diffusers operators + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias", + f"0:{d_model}, :" + ] + } + 
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias", + f"{d_model}:{d_model * 2}, :" + ] + } + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias", + f"{d_model * 2}:, :" + ] + } + except Exception as e: + print(f"Error on key {ldm_key}") + print(e) + + # update keys + diffusers_dict_keys = list(diffusers_state_dict.keys()) + +pbar = tqdm(ldm_dict_keys, desc='Matching ldm-diffusers keys', total=total_keys) +# run through all weights and check mse between them to find matches +for ldm_key in ldm_dict_keys: + ldm_shape_tuple = ldm_state_dict[ldm_key].shape + ldm_reduced_shape_tuple = get_reduced_shape(ldm_shape_tuple) + for diffusers_key in diffusers_dict_keys: + if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight" and diffusers_key == "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": + print("here") + + diffusers_shape_tuple = diffusers_state_dict[diffusers_key].shape + diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple) + + # That was easy. Same key + # if ldm_key == diffusers_key: + # ldm_diffusers_keymap[ldm_key] = diffusers_key + # matched_ldm_keys.append(ldm_key) + # matched_diffusers_keys.append(diffusers_key) + # break + + # if we already have this key mapped, skip it + if diffusers_key in matched_diffusers_keys: + continue + + # if reduced shapes do not match skip it + if ldm_reduced_shape_tuple != diffusers_reduced_shape_tuple: + continue + + ldm_weight = ldm_state_dict[ldm_key] + did_reduce_ldm = False + diffusers_weight = diffusers_state_dict[diffusers_key] + did_reduce_diffusers = False + + # reduce the shapes to match if they are not the same + if ldm_shape_tuple != ldm_reduced_shape_tuple: + ldm_weight = ldm_weight.view(ldm_reduced_shape_tuple) + did_reduce_ldm = True + + if diffusers_shape_tuple != diffusers_reduced_shape_tuple: + diffusers_weight = diffusers_weight.view(diffusers_reduced_shape_tuple) + did_reduce_diffusers = True + + # check to see if they match within a margin of error + mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float()) + if mse < error_margin: + ldm_diffusers_keymap[ldm_key] = diffusers_key + matched_ldm_keys.append(ldm_key) + matched_diffusers_keys.append(diffusers_key) + + if did_reduce_ldm or did_reduce_diffusers: + ldm_diffusers_shape_map[ldm_key] = (ldm_shape_tuple, diffusers_shape_tuple) + if did_reduce_ldm: + del ldm_weight + if did_reduce_diffusers: + del diffusers_weight + flush() + + break + + pbar.update(1) + +pbar.close() + +name = args.name +if args.sdxl: + name += '_sdxl' +elif args.ssd: + name += '_ssd' +elif args.vega: + name += '_vega' +elif args.refiner: + name += '_refiner' +elif args.sd2: + name += '_sd2' +else: + name += '_sd1' + +# if len(matched_ldm_keys) != len(matched_diffusers_keys): +unmatched_ldm_keys = [x for x in ldm_dict_keys if x not in matched_ldm_keys] +unmatched_diffusers_keys = [x for x in diffusers_dict_keys if x not in matched_diffusers_keys] +# has unmatched keys + +has_unmatched_keys = len(unmatched_ldm_keys) > 0 or len(unmatched_diffusers_keys) > 0 + + +def get_slices_from_string(s: str) -> tuple: + slice_strings = s.split(',') + slices = [eval(f"slice({component.strip()})") for component in slice_strings] + return tuple(slices) + + +if has_unmatched_keys: + + print( + f"Found 
{len(unmatched_ldm_keys)} unmatched ldm keys and {len(unmatched_diffusers_keys)} unmatched diffusers keys") + + unmatched_obj = OrderedDict() + unmatched_obj['ldm'] = OrderedDict() + unmatched_obj['diffusers'] = OrderedDict() + + print(f"Gathering info on unmatched keys") + + for key in tqdm(unmatched_ldm_keys, desc='Unmatched LDM keys'): + # get min, max, mean, std + weight = ldm_state_dict[key] + weight_min = weight.min().item() + weight_max = weight.max().item() + unmatched_obj['ldm'][key] = { + 'shape': weight.shape, + "min": weight_min, + "max": weight_max, + } + del weight + flush() + + for key in tqdm(unmatched_diffusers_keys, desc='Unmatched Diffusers keys'): + # get min, max, mean, std + weight = diffusers_state_dict[key] + weight_min = weight.min().item() + weight_max = weight.max().item() + unmatched_obj['diffusers'][key] = { + "shape": weight.shape, + "min": weight_min, + "max": weight_max, + } + del weight + flush() + + unmatched_path = os.path.join(KEYMAPS_FOLDER, f'{name}_unmatched.json') + with open(unmatched_path, 'w') as f: + f.write(json.dumps(unmatched_obj, indent=4)) + + print(f'Saved unmatched keys to {unmatched_path}') + +# save ldm remainders +remaining_ldm_values = OrderedDict() +for key in unmatched_ldm_keys: + remaining_ldm_values[key] = ldm_state_dict[key].detach().to('cpu', torch.float16) + +save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors')) +print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}') + +# do cleanup of some left overs and bugs +to_remove = [] +for ldm_key, diffusers_key in ldm_diffusers_keymap.items(): + # get rid of tmp merge keys used to slicing + if tmp_merge_key in diffusers_key or tmp_merge_key in ldm_key: + to_remove.append(ldm_key) + +for key in to_remove: + del ldm_diffusers_keymap[key] + +to_remove = [] +# remove identical shape mappings. Not sure why they exist but they do +for ldm_key, shape_list in ldm_diffusers_shape_map.items(): + # remove identical shape mappings. 
Not sure why they exist but they do + # convert to json string to make it easier to compare + ldm_shape = json.dumps(shape_list[0]) + diffusers_shape = json.dumps(shape_list[1]) + if ldm_shape == diffusers_shape: + to_remove.append(ldm_key) + +for key in to_remove: + del ldm_diffusers_shape_map[key] + +dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json') +save_obj = OrderedDict() +save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap +save_obj["ldm_diffusers_shape_map"] = ldm_diffusers_shape_map +save_obj["ldm_diffusers_operator_map"] = ldm_operator_map +save_obj["diffusers_ldm_operator_map"] = diffusers_operator_map +with open(dest_path, 'w') as f: + f.write(json.dumps(save_obj, indent=4)) + +print(f'Saved keymap to {dest_path}') diff --git a/testing/merge_in_text_encoder_adapter.py b/testing/merge_in_text_encoder_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a2983c82469a2c6e56874df8b184d30f2d23fc --- /dev/null +++ b/testing/merge_in_text_encoder_adapter.py @@ -0,0 +1,180 @@ +import os + +import torch +from transformers import T5EncoderModel, T5Tokenizer +from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel, PixArtTransformer2DModel +from safetensors.torch import load_file, save_file +from collections import OrderedDict +import json + +# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000" +# te_path = "google/flan-t5-xl" +# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors" +# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw" +model_path = "/home/jaret/Dev/models/hf/objective-reality-16ch" +te_path = "google/flan-t5-xl" +te_aug_path = "/mnt/Train2/out/ip_adapter/t5xl-sd15-16ch_v1/t5xl-sd15-16ch_v1_000115000.safetensors" +output_path = "/home/jaret/Dev/models/hf/t5xl-sd15-16ch_sd15_v1" + + +print("Loading te adapter") +te_aug_sd = load_file(te_aug_path) + +print("Loading model") +is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path) + +# if "pixart" in model_path.lower(): +is_pixart = "pixart" in model_path.lower() + +pipeline_class = StableDiffusionPipeline + +# transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16) + +if is_pixart: + pipeline_class = PixArtSigmaPipeline + +if is_diffusers: + sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16) +else: + sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16) + +print("Loading Text Encoder") +# Load the text encoder +te = T5EncoderModel.from_pretrained(te_path, torch_dtype=torch.float16) + +# patch it +sd.text_encoder = te +sd.tokenizer = T5Tokenizer.from_pretrained(te_path) + +if is_pixart: + unet = sd.transformer + unet_sd = sd.transformer.state_dict() +else: + unet = sd.unet + unet_sd = sd.unet.state_dict() + + +if is_pixart: + weight_idx = 0 +else: + weight_idx = 1 + +new_cross_attn_dim = None + +# count the num of params in state dict +start_params = sum([v.numel() for v in unet_sd.values()]) + +print("Building") +attn_processor_keys = [] +if is_pixart: + transformer: Transformer2DModel = unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") +else: + attn_processor_keys = list(unet.attn_processors.keys()) + +for name in attn_processor_keys: + 
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith( + "attn1") else \ + unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + hidden_size = unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + pass + else: + layer_name = name.split(".processor")[0] + to_k_adapter = unet_sd[layer_name + ".to_k.weight"] + to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + + te_aug_name = None + while True: + if is_pixart: + te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter" + else: + te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter" + if f"{te_aug_name}.weight" in te_aug_sd: + # increment so we dont redo it next time + weight_idx += 1 + break + else: + weight_idx += 1 + + if weight_idx > 1000: + raise ValueError("Could not find the next weight") + + orig_weight_shape_k = list(unet_sd[layer_name + ".to_k.weight"].shape) + new_weight_shape_k = list(te_aug_sd[te_aug_name + ".weight"].shape) + orig_weight_shape_v = list(unet_sd[layer_name + ".to_v.weight"].shape) + new_weight_shape_v = list(te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"].shape) + + unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"] + unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"] + + if new_cross_attn_dim is None: + new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1] + + + +if is_pixart: + # copy the caption_projection weight + del unet_sd['caption_projection.linear_1.bias'] + del unet_sd['caption_projection.linear_1.weight'] + del unet_sd['caption_projection.linear_2.bias'] + del unet_sd['caption_projection.linear_2.weight'] + +print("Saving unmodified model") +sd = sd.to("cpu", torch.float16) +sd.save_pretrained( + output_path, + safe_serialization=True, +) + +# overwrite the unet +if is_pixart: + unet_folder = os.path.join(output_path, "transformer") +else: + unet_folder = os.path.join(output_path, "unet") + +# move state_dict to cpu +unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()} + +meta = OrderedDict() +meta["format"] = "pt" + +print("Patching") + +save_file(unet_sd, os.path.join(unet_folder, "diffusion_pytorch_model.safetensors"), meta) + +# load the json file +with open(os.path.join(unet_folder, "config.json"), 'r') as f: + config = json.load(f) + +config['cross_attention_dim'] = new_cross_attn_dim + +if is_pixart: + config['caption_channels'] = None + +# save it +with open(os.path.join(unet_folder, "config.json"), 'w') as f: + json.dump(config, f, indent=2) + +print("Done") + +new_params = sum([v.numel() for v in unet_sd.values()]) + +# print new and old params with , formatted +print(f"Old params: {start_params:,}") +print(f"New params: {new_params:,}") diff --git a/testing/shrink_pixart.py b/testing/shrink_pixart.py new file mode 100644 index 0000000000000000000000000000000000000000..ad27b1a0ea38612a2a4202261ca88a7875281db1 --- /dev/null +++ b/testing/shrink_pixart.py 
@@ -0,0 +1,62 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors" +output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors" + +state_dict = load_file(model_path) + +meta = OrderedDict() +meta["format"] = "pt" + +new_state_dict = {} + +# Move non-blocks over +for key, value in state_dict.items(): + if not key.startswith("transformer_blocks."): + new_state_dict[key] = value + +block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight', + 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight', + 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight', + 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight', + 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight', + 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight', + 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight', + 'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight', + 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight', + 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight', + 'transformer_blocks.{idx}.scale_shift_table'] + +# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27 + +current_idx = 0 +for i in range(28): + if i not in [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]: + # todo merge in with previous block + for name in block_names: + try: + new_state_dict_key = name.format(idx=current_idx - 1) + old_state_dict_key = name.format(idx=i) + new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5) + except KeyError: + raise KeyError(f"KeyError: {name.format(idx=current_idx)}") + else: + for name in block_names: + new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)] + current_idx += 1 + + +# make sure they are all fp16 and on cpu +for key, value in new_state_dict.items(): + new_state_dict[key] = value.to(torch.float16).cpu() + +# save the new state dict +save_file(new_state_dict, output_path, metadata=meta) + +new_param_count = sum([v.numel() for v in new_state_dict.values()]) +old_param_count = sum([v.numel() for v in state_dict.values()]) + +print(f"Old param count: {old_param_count:,}") +print(f"New param count: {new_param_count:,}") \ No newline at end of file diff --git a/testing/shrink_pixart2.py b/testing/shrink_pixart2.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c30cf87f38610ac23b31afdc94311fba8e3a41 --- /dev/null +++ b/testing/shrink_pixart2.py @@ -0,0 +1,81 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors" +output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors" + +state_dict = load_file(model_path) + +meta = OrderedDict() +meta["format"] = "pt" + +new_state_dict = {} + +# Move non-blocks over +for key, value in 
state_dict.items(): + if not key.startswith("transformer_blocks."): + new_state_dict[key] = value + +block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight', + 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight', + 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight', + 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight', + 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight', + 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight', + 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight', + 'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight', + 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight', + 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight', + 'transformer_blocks.{idx}.scale_shift_table'] + +# Blocks to keep +# keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27] +keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27] + + +def weighted_merge(kept_block, removed_block, weight): + return kept_block * (1 - weight) + removed_block * weight + + +# First, copy all kept blocks to new_state_dict +for i, old_idx in enumerate(keep_blocks): + for name in block_names: + old_key = name.format(idx=old_idx) + new_key = name.format(idx=i) + new_state_dict[new_key] = state_dict[old_key].clone() + +# Then, merge information from removed blocks +for i in range(28): + if i not in keep_blocks: + # Find the nearest kept blocks + prev_kept = max([b for b in keep_blocks if b < i]) + next_kept = min([b for b in keep_blocks if b > i]) + + # Calculate the weight based on position + weight = (i - prev_kept) / (next_kept - prev_kept) + + for name in block_names: + removed_key = name.format(idx=i) + prev_new_key = name.format(idx=keep_blocks.index(prev_kept)) + next_new_key = name.format(idx=keep_blocks.index(next_kept)) + + # Weighted merge for previous kept block + new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight) + + # Weighted merge for next kept block + new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key], + 1 - weight) + +# Convert to fp16 and move to CPU +for key, value in new_state_dict.items(): + new_state_dict[key] = value.to(torch.float16).cpu() + +# Save the new state dict +save_file(new_state_dict, output_path, metadata=meta) + +new_param_count = sum([v.numel() for v in new_state_dict.values()]) +old_param_count = sum([v.numel() for v in state_dict.values()]) + +print(f"Old param count: {old_param_count:,}") +print(f"New param count: {new_param_count:,}") \ No newline at end of file diff --git a/testing/shrink_pixart_sm.py b/testing/shrink_pixart_sm.py new file mode 100644 index 0000000000000000000000000000000000000000..8cea07bf154fdd653f1a928f3c553dc56580a828 --- /dev/null +++ b/testing/shrink_pixart_sm.py @@ -0,0 +1,84 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +meta = OrderedDict() +meta['format'] = "pt" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def reduce_weight(weight, target_size): + weight = weight.to(device, torch.float32) + original_shape = weight.shape + flattened = weight.view(-1, 
original_shape[-1]) + + if flattened.shape[1] <= target_size: + return weight + + U, S, V = torch.svd(flattened) + reduced = torch.mm(U[:, :target_size], torch.diag(S[:target_size])) + + if reduced.shape[1] < target_size: + padding = torch.zeros(reduced.shape[0], target_size - reduced.shape[1], device=device) + reduced = torch.cat((reduced, padding), dim=1) + + return reduced.view(original_shape[:-1] + (target_size,)) + + +def reduce_bias(bias, target_size): + bias = bias.to(device, torch.float32) + original_size = bias.shape[0] + + if original_size <= target_size: + return torch.nn.functional.pad(bias, (0, target_size - original_size)) + else: + return bias.view(-1, original_size // target_size).mean(dim=1)[:target_size] + + +# Load your original state dict +state_dict = load_file( + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") + +# Create a new state dict for the reduced model +new_state_dict = {} + +source_hidden_size = 1152 +target_hidden_size = 1024 + +for key, value in state_dict.items(): + value = value.to(device, torch.float32) + if 'weight' in key or 'scale_shift_table' in key: + if value.shape[0] == source_hidden_size: + value = value[:target_hidden_size] + elif value.shape[0] == source_hidden_size * 4: + value = value[:target_hidden_size * 4] + elif value.shape[0] == source_hidden_size * 6: + value = value[:target_hidden_size * 6] + + if len(value.shape) > 1 and value.shape[ + 1] == source_hidden_size and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key: + value = value[:, :target_hidden_size] + elif len(value.shape) > 1 and value.shape[1] == source_hidden_size * 4: + value = value[:, :target_hidden_size * 4] + + elif 'bias' in key: + if value.shape[0] == source_hidden_size: + value = value[:target_hidden_size] + elif value.shape[0] == source_hidden_size * 4: + value = value[:target_hidden_size * 4] + elif value.shape[0] == source_hidden_size * 6: + value = value[:target_hidden_size * 6] + + new_state_dict[key] = value + +# Move all to CPU and convert to float16 +for key, value in new_state_dict.items(): + new_state_dict[key] = value.cpu().to(torch.float16) + +# Save the new state dict +save_file(new_state_dict, + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", + metadata=meta) + +print("Done!") diff --git a/testing/shrink_pixart_sm2.py b/testing/shrink_pixart_sm2.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3304dfc72e50e445fce27ae793e7544009aa1a --- /dev/null +++ b/testing/shrink_pixart_sm2.py @@ -0,0 +1,110 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +meta = OrderedDict() +meta['format'] = "pt" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def reduce_weight(weight, target_size): + weight = weight.to(device, torch.float32) + original_shape = weight.shape + + if len(original_shape) == 1: + # For 1D tensors, simply truncate + return weight[:target_size] + + if original_shape[0] <= target_size: + return weight + + # Reshape the tensor to 2D + flattened = weight.reshape(original_shape[0], -1) + + # Perform SVD + U, S, V = torch.svd(flattened) + + # Reduce the dimensions + reduced = torch.mm(U[:target_size, :], torch.diag(S)).mm(V.t()) + + # Reshape back to the original shape with reduced first dimension + new_shape = (target_size,) + original_shape[1:] + return reduced.reshape(new_shape) + + +def 
reduce_bias(bias, target_size): + bias = bias.to(device, torch.float32) + return bias[:target_size] + + +# Load your original state dict +state_dict = load_file( + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") + +# Create a new state dict for the reduced model +new_state_dict = {} + +for key, value in state_dict.items(): + value = value.to(device, torch.float32) + + if 'weight' in key or 'scale_shift_table' in key: + if value.shape[0] == 1152: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3]) # reshape to (1152, -1) + # reshape to (1152, -1) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 512) + value = value.view(output_shape) + else: + # value = reduce_weight(value.t(), 576).t().contiguous() + value = reduce_weight(value, 512) + pass + elif value.shape[0] == 4608: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3]) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 2048) + value = value.view(output_shape) + else: + value = reduce_weight(value, 2048) + elif value.shape[0] == 6912: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3]) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 3072) + value = value.view(output_shape) + else: + value = reduce_weight(value, 3072) + + if len(value.shape) > 1 and value.shape[ + 1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key: + value = reduce_weight(value.t(), 512).t().contiguous() # Transpose before and after reduction + pass + elif len(value.shape) > 1 and value.shape[1] == 4608: + value = reduce_weight(value.t(), 2048).t().contiguous() # Transpose before and after reduction + pass + + elif 'bias' in key: + if value.shape[0] == 1152: + value = reduce_bias(value, 512) + elif value.shape[0] == 4608: + value = reduce_bias(value, 2048) + elif value.shape[0] == 6912: + value = reduce_bias(value, 3072) + + new_state_dict[key] = value + +# Move all to CPU and convert to float16 +for key, value in new_state_dict.items(): + new_state_dict[key] = value.cpu().to(torch.float16) + +# Save the new state dict +save_file(new_state_dict, + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", + metadata=meta) + +print("Done!") \ No newline at end of file diff --git a/testing/shrink_pixart_sm3.py b/testing/shrink_pixart_sm3.py new file mode 100644 index 0000000000000000000000000000000000000000..b8756aec45b4a5cb59315ab11a5bed320d74f7ba --- /dev/null +++ b/testing/shrink_pixart_sm3.py @@ -0,0 +1,100 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +meta = OrderedDict() +meta['format'] = "pt" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def reduce_weight(weight, target_size): + weight = weight.to(device, torch.float32) + # resize so target_size is the first dimension + tmp_weight = weight.view(1, 1, weight.shape[0], weight.shape[1]) + + # use interpolate to resize the tensor + new_weight = torch.nn.functional.interpolate(tmp_weight, size=(target_size, weight.shape[1]), mode='bicubic', align_corners=True) + + # reshape back to original shape + return new_weight.view(target_size, weight.shape[1]) + + +def reduce_bias(bias, target_size): + bias 
= bias.view(1, 1, bias.shape[0], 1) + + new_bias = torch.nn.functional.interpolate(bias, size=(target_size, 1), mode='bicubic', align_corners=True) + + return new_bias.view(target_size) + + +# Load your original state dict +state_dict = load_file( + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") + +# Create a new state dict for the reduced model +new_state_dict = {} + +for key, value in state_dict.items(): + value = value.to(device, torch.float32) + + if 'weight' in key or 'scale_shift_table' in key: + if value.shape[0] == 1152: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3]) # reshape to (1152, -1) + # reshape to (1152, -1) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 512) + value = value.view(output_shape) + else: + # value = reduce_weight(value.t(), 576).t().contiguous() + value = reduce_weight(value, 512) + pass + elif value.shape[0] == 4608: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3]) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 2048) + value = value.view(output_shape) + else: + value = reduce_weight(value, 2048) + elif value.shape[0] == 6912: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3]) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 3072) + value = value.view(output_shape) + else: + value = reduce_weight(value, 3072) + + if len(value.shape) > 1 and value.shape[ + 1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key: + value = reduce_weight(value.t(), 512).t().contiguous() # Transpose before and after reduction + pass + elif len(value.shape) > 1 and value.shape[1] == 4608: + value = reduce_weight(value.t(), 2048).t().contiguous() # Transpose before and after reduction + pass + + elif 'bias' in key: + if value.shape[0] == 1152: + value = reduce_bias(value, 512) + elif value.shape[0] == 4608: + value = reduce_bias(value, 2048) + elif value.shape[0] == 6912: + value = reduce_bias(value, 3072) + + new_state_dict[key] = value + +# Move all to CPU and convert to float16 +for key, value in new_state_dict.items(): + new_state_dict[key] = value.cpu().to(torch.float16) + +# Save the new state dict +save_file(new_state_dict, + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", + metadata=meta) + +print("Done!") \ No newline at end of file diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..31d97f2d92949de05dbc831cbdbdb764a5997dca --- /dev/null +++ b/testing/test_bucket_dataloader.py @@ -0,0 +1,128 @@ +import time + +import numpy as np +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +import sys +import os +import cv2 +import random +from transformers import CLIPImageProcessor + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from toolkit.paths import SD_SCRIPTS_ROOT +import torchvision.transforms.functional +from toolkit.image_utils import show_img, show_tensors + +sys.path.append(SD_SCRIPTS_ROOT) + +from library.model_util import load_vae +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.data_loader import 
AiToolkitDataset, get_dataloader_from_datasets, \ + trigger_dataloader_setup_epoch +from toolkit.config_modules import DatasetConfig +import argparse +from tqdm import tqdm + +parser = argparse.ArgumentParser() +parser.add_argument('dataset_folder', type=str, default='input') +parser.add_argument('--epochs', type=int, default=1) + + + +args = parser.parse_args() + +dataset_folder = args.dataset_folder +resolution = 1024 +bucket_tolerance = 64 +batch_size = 1 + +clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16") + +class FakeAdapter: + def __init__(self): + self.clip_image_processor = clip_processor + + +## make fake sd +class FakeSD: + def __init__(self): + self.adapter = FakeAdapter() + + + + +dataset_config = DatasetConfig( + dataset_path=dataset_folder, + # clip_image_path=dataset_folder, + # square_crop=True, + resolution=resolution, + # caption_ext='json', + default_caption='default', + # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/', + buckets=True, + bucket_tolerance=bucket_tolerance, + # poi='person', + # shuffle_augmentations=True, + # augmentations=[ + # { + # 'method': 'Posterize', + # 'num_bits': [(0, 4), (0, 4), (0, 4)], + # 'p': 1.0 + # }, + # + # ] +) + +dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD()) + + +# run through an epoch ang check sizes +dataloader_iterator = iter(dataloader) +for epoch in range(args.epochs): + for batch in tqdm(dataloader): + batch: 'DataLoaderBatchDTO' + img_batch = batch.tensor + batch_size, channels, height, width = img_batch.shape + + # img_batch = color_block_imgs(img_batch, neg1_1=True) + + # chunks = torch.chunk(img_batch, batch_size, dim=0) + # # put them so they are size by side + # big_img = torch.cat(chunks, dim=3) + # big_img = big_img.squeeze(0) + # + # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0) + # big_control_img = torch.cat(control_chunks, dim=3) + # big_control_img = big_control_img.squeeze(0) * 2 - 1 + # + # + # # resize control image + # big_control_img = torchvision.transforms.Resize((width, height))(big_control_img) + # + # big_img = torch.cat([big_img, big_control_img], dim=2) + # + # min_val = big_img.min() + # max_val = big_img.max() + # + # big_img = (big_img / 2 + 0.5).clamp(0, 1) + + big_img = img_batch + # big_img = big_img.clamp(-1, 1) + + show_tensors(big_img) + + # convert to image + # img = transforms.ToPILImage()(big_img) + # + # show_img(img) + + time.sleep(0.2) + # if not last epoch + if epoch < args.epochs - 1: + trigger_dataloader_setup_epoch(dataloader) + +cv2.destroyAllWindows() + +print('done') diff --git a/testing/test_model_load_save.py b/testing/test_model_load_save.py new file mode 100644 index 0000000000000000000000000000000000000000..87bdfb3ef8246268f0660db6bf24822c74506c45 --- /dev/null +++ b/testing/test_model_load_save.py @@ -0,0 +1,172 @@ +import argparse +import os +# add project root to sys path +import sys + +from tqdm import tqdm + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +from diffusers.loaders import LoraLoaderMixin +from safetensors.torch import load_file +from collections import OrderedDict +import json + +from toolkit.config_modules import ModelConfig +from toolkit.paths import KEYMAPS_ROOT +from toolkit.saving import convert_state_dict_to_ldm_with_mapping, get_ldm_state_dict_from_diffusers +from toolkit.stable_diffusion_model import StableDiffusion + +# this was just used to match the vae keys 
to the diffusers keys +# you probably wont need this. Unless they change them.... again... again +# on second thought, you probably will + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +device = torch.device('cpu') +dtype = torch.float32 + +parser = argparse.ArgumentParser() + +# require at lease one config file +parser.add_argument( + 'file_1', + nargs='+', + type=str, + help='Path an LDM model' +) + +parser.add_argument( + '--is_xl', + action='store_true', + help='Is the model an XL model' +) + +parser.add_argument( + '--is_v2', + action='store_true', + help='Is the model a v2 model' +) + +args = parser.parse_args() + +find_matches = False + +print("Loading model") +state_dict_file_1 = load_file(args.file_1[0]) +state_dict_1_keys = list(state_dict_file_1.keys()) + +print("Loading model into diffusers format") +model_config = ModelConfig( + name_or_path=args.file_1[0], + is_xl=args.is_xl +) +sd = StableDiffusion( + model_config=model_config, + device=device, +) +sd.load_model() + +# load our base +base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors') +mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json') + +print("Converting model back to LDM") +version_string = '1' +if args.is_v2: + version_string = '2' +if args.is_xl: + version_string = 'sdxl' +# convert the state dict +state_dict_file_2 = get_ldm_state_dict_from_diffusers( + sd.state_dict(), + version_string, + device='cpu', + dtype=dtype +) + +# state_dict_file_2 = load_file(args.file_2[0]) + +state_dict_2_keys = list(state_dict_file_2.keys()) +keys_in_both = [] + +keys_not_in_state_dict_2 = [] +for key in state_dict_1_keys: + if key not in state_dict_2_keys: + keys_not_in_state_dict_2.append(key) + +keys_not_in_state_dict_1 = [] +for key in state_dict_2_keys: + if key not in state_dict_1_keys: + keys_not_in_state_dict_1.append(key) + +keys_in_both = [] +for key in state_dict_1_keys: + if key in state_dict_2_keys: + keys_in_both.append(key) + +# sort them +keys_not_in_state_dict_2.sort() +keys_not_in_state_dict_1.sort() +keys_in_both.sort() + +if len(keys_not_in_state_dict_2) == 0 and len(keys_not_in_state_dict_1) == 0: + print("All keys match!") + print("Checking values...") + mismatch_keys = [] + loss = torch.nn.MSELoss() + tolerance = 1e-6 + for key in tqdm(keys_in_both): + if loss(state_dict_file_1[key], state_dict_file_2[key]) > tolerance: + print(f"Values for key {key} don't match!") + print(f"Loss: {loss(state_dict_file_1[key], state_dict_file_2[key])}") + mismatch_keys.append(key) + + if len(mismatch_keys) == 0: + print("All values match!") + else: + print("Some valued font match!") + print(mismatch_keys) + mismatched_path = os.path.join(project_root, 'config', 'mismatch.json') + with open(mismatched_path, 'w') as f: + f.write(json.dumps(mismatch_keys, indent=4)) + exit(0) + +else: + print("Keys don't match!, generating info...") + +json_data = { + "both": keys_in_both, + "not_in_state_dict_2": keys_not_in_state_dict_2, + "not_in_state_dict_1": keys_not_in_state_dict_1 +} +json_data = json.dumps(json_data, indent=4) + +remaining_diffusers_values = OrderedDict() +for key in keys_not_in_state_dict_1: + remaining_diffusers_values[key] = state_dict_file_2[key] + +# print(remaining_diffusers_values.keys()) + +remaining_ldm_values = OrderedDict() +for key in keys_not_in_state_dict_2: + remaining_ldm_values[key] = state_dict_file_1[key] + +# print(json_data) + + +json_save_path = os.path.join(project_root, 'config', 'keys.json') +json_matched_save_path = 
os.path.join(project_root, 'config', 'matched.json') +json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') +state_dict_1_filename = os.path.basename(args.file_1[0]) +# state_dict_2_filename = os.path.basename(args.file_2[0]) +# save key names for each in own file +with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f: + f.write(json.dumps(state_dict_1_keys, indent=4)) + +with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}_loop.json'), 'w') as f: + f.write(json.dumps(state_dict_2_keys, indent=4)) + +with open(json_save_path, 'w') as f: + f.write(json_data) diff --git a/testing/test_vae.py b/testing/test_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..44b31f6311024833b1dc4c46cf431f7ae68f09d7 --- /dev/null +++ b/testing/test_vae.py @@ -0,0 +1,113 @@ +import argparse +import os +from PIL import Image +import torch +from torchvision.transforms import Resize, ToTensor +from diffusers import AutoencoderKL +from pytorch_fid import fid_score +from skimage.metrics import peak_signal_noise_ratio as psnr +import lpips +from tqdm import tqdm +from torchvision import transforms + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def load_images(folder_path): + images = [] + for filename in os.listdir(folder_path): + if filename.lower().endswith(('.png', '.jpg', '.jpeg')): + img_path = os.path.join(folder_path, filename) + images.append(img_path) + return images + + +def paramiter_count(model): + state_dict = model.state_dict() + paramiter_count = 0 + for key in state_dict: + paramiter_count += torch.numel(state_dict[key]) + return int(paramiter_count) + + +def calculate_metrics(vae, images, max_imgs=-1): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + vae = vae.to(device) + lpips_model = lpips.LPIPS(net='alex').to(device) + + rfid_scores = [] + psnr_scores = [] + lpips_scores = [] + + # transform = transforms.Compose([ + # transforms.Resize(256, antialias=True), + # transforms.CenterCrop(256) + # ]) + # needs values between -1 and 1 + to_tensor = ToTensor() + + if max_imgs > 0 and len(images) > max_imgs: + images = images[:max_imgs] + + for img_path in tqdm(images): + try: + img = Image.open(img_path).convert('RGB') + # img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device) + img_tensor = to_tensor(img).unsqueeze(0).to(device) + img_tensor = 2 * img_tensor - 1 + # if width or height is not divisible by 8, crop it + if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0: + img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8] + + except Exception as e: + print(f"Error processing {img_path}: {e}") + continue + + + with torch.no_grad(): + reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample + + # Calculate rFID + # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed) + # rfid_scores.append(rfid) + + # Calculate PSNR + psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy()) + psnr_scores.append(psnr_val) + + # Calculate LPIPS + lpips_val = lpips_model(img_tensor, reconstructed).item() + lpips_scores.append(lpips_val) + + # avg_rfid = sum(rfid_scores) / len(rfid_scores) + avg_rfid = 0 + avg_psnr = sum(psnr_scores) / len(psnr_scores) + avg_lpips = sum(lpips_scores) / len(lpips_scores) + + return avg_rfid, avg_psnr, avg_lpips + + +def main(): + parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE 
reconstructions") + parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model") + parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images") + parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.") + args = parser.parse_args() + + if os.path.isfile(args.vae_path): + vae = AutoencoderKL.from_single_file(args.vae_path) + else: + vae = AutoencoderKL.from_pretrained(args.vae_path) + vae.eval() + vae = vae.to(device) + print(f"Model has {paramiter_count(vae)} parameters") + images = load_images(args.image_folder) + + avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs) + + # print(f"Average rFID: {avg_rfid}") + print(f"Average PSNR: {avg_psnr}") + print(f"Average LPIPS: {avg_lpips}") + + +if __name__ == "__main__": + main() diff --git a/testing/test_vae_cycle.py b/testing/test_vae_cycle.py new file mode 100644 index 0000000000000000000000000000000000000000..175e8f8fa5cdb4cb652225f4d95e7a2cbb04fd29 --- /dev/null +++ b/testing/test_vae_cycle.py @@ -0,0 +1,112 @@ +import os + +import torch +from safetensors.torch import load_file +from collections import OrderedDict +from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm, vae_keys_squished_on_diffusers +import json +# this was just used to match the vae keys to the diffusers keys +# you probably wont need this. Unless they change them.... again... again +# on second thought, you probably will + +device = torch.device('cpu') +dtype = torch.float32 +vae_path = '/mnt/Models/stable-diffusion/models/VAE/vae-ft-mse-840000-ema-pruned/vae-ft-mse-840000-ema-pruned.safetensors' + +find_matches = False + +state_dict_ldm = load_file(vae_path) +diffusers_vae = load_vae(vae_path, dtype=torch.float32).to(device) + +ldm_keys = state_dict_ldm.keys() + +matched_keys = {} +duplicated_keys = { + +} + +if find_matches: + # find values that match with a very low mse + for ldm_key in ldm_keys: + ldm_value = state_dict_ldm[ldm_key] + for diffusers_key in list(diffusers_vae.state_dict().keys()): + diffusers_value = diffusers_vae.state_dict()[diffusers_key] + if diffusers_key in vae_keys_squished_on_diffusers: + diffusers_value = diffusers_value.clone().unsqueeze(-1).unsqueeze(-1) + # if they are not same shape, skip + if ldm_value.shape != diffusers_value.shape: + continue + mse = torch.nn.functional.mse_loss(ldm_value, diffusers_value) + if mse < 1e-6: + if ldm_key in list(matched_keys.keys()): + print(f'{ldm_key} already matched to {matched_keys[ldm_key]}') + if ldm_key in duplicated_keys: + duplicated_keys[ldm_key].append(diffusers_key) + else: + duplicated_keys[ldm_key] = [diffusers_key] + continue + matched_keys[ldm_key] = diffusers_key + is_matched = True + break + + print(f'Found {len(matched_keys)} matches') + +dif_to_ldm_state_dict = convert_diffusers_back_to_ldm(diffusers_vae) +dif_to_ldm_state_dict_keys = list(dif_to_ldm_state_dict.keys()) +keys_in_both = [] + +keys_not_in_diffusers = [] +for key in ldm_keys: + if key not in dif_to_ldm_state_dict_keys: + keys_not_in_diffusers.append(key) + +keys_not_in_ldm = [] +for key in dif_to_ldm_state_dict_keys: + if key not in ldm_keys: + keys_not_in_ldm.append(key) + +keys_in_both = [] +for key in ldm_keys: + if key in dif_to_ldm_state_dict_keys: + keys_in_both.append(key) + +# sort them +keys_not_in_diffusers.sort() +keys_not_in_ldm.sort() +keys_in_both.sort() + +# print(f'Keys in LDM but not in Diffusers: 
{len(keys_not_in_diffusers)}{keys_not_in_diffusers}') +# print(f'Keys in Diffusers but not in LDM: {len(keys_not_in_ldm)}{keys_not_in_ldm}') +# print(f'Keys in both: {len(keys_in_both)}{keys_in_both}') + +json_data = { + "both": keys_in_both, + "ldm": keys_not_in_diffusers, + "diffusers": keys_not_in_ldm +} +json_data = json.dumps(json_data, indent=4) + +remaining_diffusers_values = OrderedDict() +for key in keys_not_in_ldm: + remaining_diffusers_values[key] = dif_to_ldm_state_dict[key] + +# print(remaining_diffusers_values.keys()) + +remaining_ldm_values = OrderedDict() +for key in keys_not_in_diffusers: + remaining_ldm_values[key] = state_dict_ldm[key] + +# print(json_data) + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +json_save_path = os.path.join(project_root, 'config', 'keys.json') +json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') +json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') + +with open(json_save_path, 'w') as f: + f.write(json_data) +if find_matches: + with open(json_matched_save_path, 'w') as f: + f.write(json.dumps(matched_keys, indent=4)) + with open(json_duped_save_path, 'w') as f: + f.write(json.dumps(duplicated_keys, indent=4)) diff --git a/toolkit/__init__.py b/toolkit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/toolkit/assistant_lora.py b/toolkit/assistant_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeca968ad6c6d5b403f3786eea81efc33944c94 --- /dev/null +++ b/toolkit/assistant_lora.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING +from toolkit.config_modules import NetworkConfig +from toolkit.lora_special import LoRASpecialNetwork +from safetensors.torch import load_file + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + + +def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork: + if not sd.is_flux: + raise ValueError("Only Flux models can load assistant adapters currently.") + pipe = sd.pipeline + print(f"Loading assistant adapter from {adapter_path}") + adapter_name = adapter_path.split("/")[-1].split(".")[0] + lora_state_dict = load_file(adapter_path) + + linear_dim = int(lora_state_dict['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'].shape[0]) + # linear_alpha = int(lora_state_dict['lora_transformer_single_transformer_blocks_0_attn_to_k.alpha'].item()) + linear_alpha = linear_dim + transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict + # get dim and scale + network_config = NetworkConfig( + linear=linear_dim, + linear_alpha=linear_alpha, + transformer_only=transformer_only, + ) + + network = LoRASpecialNetwork( + text_encoder=pipe.text_encoder, + unet=pipe.transformer, + lora_dim=network_config.linear, + multiplier=1.0, + alpha=network_config.linear_alpha, + train_unet=True, + train_text_encoder=False, + is_flux=True, + network_config=network_config, + network_type=network_config.type, + transformer_only=network_config.transformer_only, + is_assistant_adapter=True + ) + network.apply_to( + pipe.text_encoder, + pipe.transformer, + apply_text_encoder=False, + apply_unet=True + ) + network.force_to(sd.device_torch, dtype=sd.torch_dtype) + network.eval() + network._update_torch_multiplier() + network.load_weights(lora_state_dict) + network.is_active = True + + return network diff --git a/toolkit/basic.py b/toolkit/basic.py new file mode 100644 index 
0000000000000000000000000000000000000000..0d32a9d2f356bbda1d3e629c6ebc6688b5f1d458 --- /dev/null +++ b/toolkit/basic.py @@ -0,0 +1,56 @@ +import gc + +import torch + + +def value_map(inputs, min_in, max_in, min_out, max_out): + return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out + + +def flush(garbage_collect=True): + torch.cuda.empty_cache() + if garbage_collect: + gc.collect() + + +def get_mean_std(tensor): + if len(tensor.shape) == 3: + tensor = tensor.unsqueeze(0) + elif len(tensor.shape) != 4: + raise Exception("Expected tensor of shape (batch_size, channels, width, height)") + mean, variance = torch.mean( + tensor, dim=[2, 3], keepdim=True + ), torch.var( + tensor, dim=[2, 3], + keepdim=True + ) + std = torch.sqrt(variance + 1e-5) + return mean, std + + +def adain(content_features, style_features): + # Assumes that the content and style features are of shape (batch_size, channels, width, height) + + dims = [2, 3] + if len(content_features.shape) == 3: + # content_features = content_features.unsqueeze(0) + # style_features = style_features.unsqueeze(0) + dims = [1] + + # Step 1: Calculate mean and variance of content features + content_mean, content_var = torch.mean(content_features, dim=dims, keepdim=True), torch.var(content_features, + dim=dims, + keepdim=True) + # Step 2: Calculate mean and variance of style features + style_mean, style_var = torch.mean(style_features, dim=dims, keepdim=True), torch.var(style_features, dim=dims, + keepdim=True) + + # Step 3: Normalize content features + content_std = torch.sqrt(content_var + 1e-5) + normalized_content = (content_features - content_mean) / content_std + + # Step 4: Scale and shift normalized content with style's statistics + style_std = torch.sqrt(style_var + 1e-5) + stylized_content = normalized_content * style_std + style_mean + + return stylized_content diff --git a/toolkit/buckets.py b/toolkit/buckets.py new file mode 100644 index 0000000000000000000000000000000000000000..835c9eb96f5bb38bb6d871530b89cb835dc47091 --- /dev/null +++ b/toolkit/buckets.py @@ -0,0 +1,174 @@ +from typing import Type, List, Union, TypedDict + + +class BucketResolution(TypedDict): + width: int + height: int + + +# resolutions SDXL was trained on with a 1024x1024 base resolution +resolutions_1024: List[BucketResolution] = [ + # SDXL Base resolution + {"width": 1024, "height": 1024}, + # SDXL Resolutions, widescreen + {"width": 2048, "height": 512}, + {"width": 1984, "height": 512}, + {"width": 1920, "height": 512}, + {"width": 1856, "height": 512}, + {"width": 1792, "height": 576}, + {"width": 1728, "height": 576}, + {"width": 1664, "height": 576}, + {"width": 1600, "height": 640}, + {"width": 1536, "height": 640}, + {"width": 1472, "height": 704}, + {"width": 1408, "height": 704}, + {"width": 1344, "height": 704}, + {"width": 1344, "height": 768}, + {"width": 1280, "height": 768}, + {"width": 1216, "height": 832}, + {"width": 1152, "height": 832}, + {"width": 1152, "height": 896}, + {"width": 1088, "height": 896}, + {"width": 1088, "height": 960}, + {"width": 1024, "height": 960}, + # SDXL Resolutions, portrait + {"width": 960, "height": 1024}, + {"width": 960, "height": 1088}, + {"width": 896, "height": 1088}, + {"width": 896, "height": 1152}, # 2:3 + {"width": 832, "height": 1152}, + {"width": 832, "height": 1216}, + {"width": 768, "height": 1280}, + {"width": 768, "height": 1344}, + {"width": 704, "height": 1408}, + {"width": 704, "height": 1472}, + {"width": 640, "height": 1536}, + {"width": 640, "height": 1600}, + {"width": 
576, "height": 1664}, + {"width": 576, "height": 1728}, + {"width": 576, "height": 1792}, + {"width": 512, "height": 1856}, + {"width": 512, "height": 1920}, + {"width": 512, "height": 1984}, + {"width": 512, "height": 2048}, + # extra wides + {"width": 8192, "height": 128}, + {"width": 128, "height": 8192}, +] + +# Even numbers so they can be patched easier +resolutions_dit_1024: List[BucketResolution] = [ + # Base resolution + {"width": 1024, "height": 1024}, + # widescreen + {"width": 2048, "height": 512}, + {"width": 1792, "height": 576}, + {"width": 1728, "height": 576}, + {"width": 1664, "height": 576}, + {"width": 1600, "height": 640}, + {"width": 1536, "height": 640}, + {"width": 1472, "height": 704}, + {"width": 1408, "height": 704}, + {"width": 1344, "height": 704}, + {"width": 1344, "height": 768}, + {"width": 1280, "height": 768}, + {"width": 1216, "height": 832}, + {"width": 1152, "height": 832}, + {"width": 1152, "height": 896}, + {"width": 1088, "height": 896}, + {"width": 1088, "height": 960}, + {"width": 1024, "height": 960}, + # portrait + {"width": 960, "height": 1024}, + {"width": 960, "height": 1088}, + {"width": 896, "height": 1088}, + {"width": 896, "height": 1152}, # 2:3 + {"width": 832, "height": 1152}, + {"width": 832, "height": 1216}, + {"width": 768, "height": 1280}, + {"width": 768, "height": 1344}, + {"width": 704, "height": 1408}, + {"width": 704, "height": 1472}, + {"width": 640, "height": 1536}, + {"width": 640, "height": 1600}, + {"width": 576, "height": 1664}, + {"width": 576, "height": 1728}, + {"width": 576, "height": 1792}, + {"width": 512, "height": 1856}, + {"width": 512, "height": 1920}, + {"width": 512, "height": 1984}, + {"width": 512, "height": 2048}, +] + + +def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]: + # determine scaler form 1024 to resolution + scaler = resolution / 1024 + + bucket_size_list = [] + for bucket in resolutions_1024: + # must be divisible by 8 + width = int(bucket["width"] * scaler) + height = int(bucket["height"] * scaler) + if width % divisibility != 0: + width = width - (width % divisibility) + if height % divisibility != 0: + height = height - (height % divisibility) + bucket_size_list.append({"width": width, "height": height}) + + return bucket_size_list + + +def get_resolution(width, height): + num_pixels = width * height + # determine same number of pixels for square image + square_resolution = int(num_pixels ** 0.5) + return square_resolution + + +def get_bucket_for_image_size( + width: int, + height: int, + bucket_size_list: List[BucketResolution] = None, + resolution: Union[int, None] = None, + divisibility: int = 8 +) -> BucketResolution: + + if bucket_size_list is None and resolution is None: + # get resolution from width and height + resolution = get_resolution(width, height) + if bucket_size_list is None: + # if real resolution is smaller, use that instead + real_resolution = get_resolution(width, height) + resolution = min(resolution, real_resolution) + bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility) + + # Check for exact match first + for bucket in bucket_size_list: + if bucket["width"] == width and bucket["height"] == height: + return bucket + + # If exact match not found, find the closest bucket + closest_bucket = None + min_removed_pixels = float("inf") + + for bucket in bucket_size_list: + scale_w = bucket["width"] / width + scale_h = bucket["height"] / height + + # To minimize pixels, we use the larger scale factor to 
minimize the amount that has to be cropped. + scale = max(scale_w, scale_h) + + new_width = int(width * scale) + new_height = int(height * scale) + + removed_pixels = (new_width - bucket["width"]) * new_height + (new_height - bucket["height"]) * new_width + + if removed_pixels < min_removed_pixels: + min_removed_pixels = removed_pixels + closest_bucket = bucket + + if closest_bucket is None: + raise ValueError("No suitable bucket found") + + return closest_bucket diff --git a/toolkit/civitai.py b/toolkit/civitai.py new file mode 100644 index 0000000000000000000000000000000000000000..ef505ad833f951470eb2e6a9c7b26059a6509604 --- /dev/null +++ b/toolkit/civitai.py @@ -0,0 +1,217 @@ +from toolkit.paths import MODELS_PATH +import requests +import os +import json +import tqdm + + +class ModelCache: + def __init__(self): + self.raw_cache = {} + self.cache_path = os.path.join(MODELS_PATH, '.ai_toolkit_cache.json') + if os.path.exists(self.cache_path): + with open(self.cache_path, 'r') as f: + all_cache = json.load(f) + if 'models' in all_cache: + self.raw_cache = all_cache['models'] + else: + self.raw_cache = all_cache + + def get_model_path(self, model_id: int, model_version_id: int = None): + if str(model_id) not in self.raw_cache: + return None + if model_version_id is None: + # get latest version + model_version_id = max([int(x) for x in self.raw_cache[str(model_id)].keys()]) + if model_version_id is None: + return None + model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path'] + # check if model path exists + if not os.path.exists(model_path): + # remove version from cache + del self.raw_cache[str(model_id)][str(model_version_id)] + self.save() + return None + return model_path + else: + if str(model_version_id) not in self.raw_cache[str(model_id)]: + return None + model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path'] + # check if model path exists + if not os.path.exists(model_path): + # remove version from cache + del self.raw_cache[str(model_id)][str(model_version_id)] + self.save() + return None + return model_path + + def update_cache(self, model_id: int, model_version_id: int, model_path: str): + if str(model_id) not in self.raw_cache: + self.raw_cache[str(model_id)] = {} + if str(model_version_id) not in self.raw_cache[str(model_id)]: + self.raw_cache[str(model_id)][str(model_version_id)] = {} + self.raw_cache[str(model_id)][str(model_version_id)] = { + 'model_path': model_path + } + self.save() + + def save(self): + if not os.path.exists(os.path.dirname(self.cache_path)): + os.makedirs(os.path.dirname(self.cache_path), exist_ok=True) + all_cache = {'models': {}} + if os.path.exists(self.cache_path): + # load it first + with open(self.cache_path, 'r') as f: + all_cache = json.load(f) + + all_cache['models'] = self.raw_cache + + with open(self.cache_path, 'w') as f: + json.dump(all_cache, f, indent=2) + + +def get_model_download_info(model_id: int, model_version_id: int = None): + # curl https://civitai.com/api/v1/models?limit=3&types=TextualInversion \ + # -H "Content-Type: application/json" \ + # -X GET + print( + f"Getting model info for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}") + endpoint = f"https://civitai.com/api/v1/models/{model_id}" + + # get the json + response = requests.get(endpoint) + response.raise_for_status() + model_data = response.json() + + model_version = None + + # go through versions and get the top one if one is not set + for version in 
model_data['modelVersions']: + if model_version_id is not None: + if str(version['id']) == str(model_version_id): + model_version = version + break + else: + # get first version + model_version = version + break + + if model_version is None: + raise ValueError( + f"Could not find a model version for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}") + + model_file = None + # go through files and prefer fp16 safetensors + # "metadata": { + # "fp": "fp16", + # "size": "pruned", + # "format": "SafeTensor" + # }, + # todo check pickle scans and skip if not good + # try to get fp16 safetensor + for file in model_version['files']: + if file['metadata']['fp'] == 'fp16' and file['metadata']['format'] == 'SafeTensor': + model_file = file + break + + if model_file is None: + # try to get primary + for file in model_version['files']: + if file['primary']: + model_file = file + break + + if model_file is None: + # try to get any safetensor + for file in model_version['files']: + if file['metadata']['format'] == 'SafeTensor': + model_file = file + break + + if model_file is None: + # try to get any fp16 + for file in model_version['files']: + if file['metadata']['fp'] == 'fp16': + model_file = file + break + + if model_file is None: + # try to get any + for file in model_version['files']: + model_file = file + break + + if model_file is None: + raise ValueError(f"Could not find a model file to download for model id: {model_id}") + + return model_file, model_version['id'] + + +def get_model_path_from_url(url: str): + # get query params form url if they are set + # https: // civitai.com / models / 25694?modelVersionId = 127742 + query_params = {} + if '?' in url: + query_string = url.split('?')[1] + query_params = dict(qc.split("=") for qc in query_string.split("&")) + + # get model id from url + model_id = url.split('/')[-1] + # remove query params from model id + if '?' 
in model_id: + model_id = model_id.split('?')[0] + if model_id.isdigit(): + model_id = int(model_id) + else: + raise ValueError(f"Invalid model id: {model_id}") + + model_cache = ModelCache() + model_path = model_cache.get_model_path(model_id, query_params.get('modelVersionId', None)) + if model_path is not None: + return model_path + else: + # download model + file_info, model_version_id = get_model_download_info(model_id, query_params.get('modelVersionId', None)) + + download_url = file_info['downloadUrl'] # url does not work directly + size_kb = file_info['sizeKB'] + filename = file_info['name'] + model_path = os.path.join(MODELS_PATH, filename) + + # download model + print(f"Did not find model locally, downloading from model from: {download_url}") + + # use tqdm to show status of downlod + response = requests.get(download_url, stream=True) + response.raise_for_status() + total_size_in_bytes = int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + tmp_path = os.path.join(MODELS_PATH, f".download_tmp_{filename}") + os.makedirs(os.path.dirname(model_path), exist_ok=True) + # remove tmp file if it exists + if os.path.exists(tmp_path): + os.remove(tmp_path) + + try: + + with open(tmp_path, 'wb') as f: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + f.write(data) + progress_bar.close() + # move to final path + os.rename(tmp_path, model_path) + model_cache.update_cache(model_id, model_version_id, model_path) + + return model_path + except Exception as e: + # remove tmp file + os.remove(tmp_path) + raise e + + +# if is main +if __name__ == '__main__': + model_path = get_model_path_from_url("https://civitai.com/models/25694?modelVersionId=127742") + print(model_path) diff --git a/toolkit/clip_vision_adapter.py b/toolkit/clip_vision_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccc920caac68aa43a2b1ddc944079d88feb50a4 --- /dev/null +++ b/toolkit/clip_vision_adapter.py @@ -0,0 +1,406 @@ +from typing import TYPE_CHECKING, Mapping, Any + +import torch +import weakref + +from toolkit.config_modules import AdapterConfig +from toolkit.models.clip_fusion import ZipperBlock +from toolkit.models.zipper_resampler import ZipperModule +from toolkit.prompt_utils import PromptEmbeds +from toolkit.train_tools import get_torch_dtype + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPVisionModel +) + +from toolkit.resampler import Resampler + +import torch.nn as nn + + +class Embedder(nn.Module): + def __init__( + self, + num_input_tokens: int = 1, + input_dim: int = 1024, + num_output_tokens: int = 8, + output_dim: int = 768, + mid_dim: int = 1024 + ): + super(Embedder, self).__init__() + self.num_output_tokens = num_output_tokens + self.num_input_tokens = num_input_tokens + self.input_dim = input_dim + self.output_dim = output_dim + + self.layer_norm = nn.LayerNorm(input_dim) + self.fc1 = nn.Linear(input_dim, mid_dim) + self.gelu = nn.GELU() + # self.fc2 = nn.Linear(mid_dim, mid_dim) + self.fc2 = nn.Linear(mid_dim, mid_dim) + + self.fc2.weight.data.zero_() + + self.layer_norm2 = nn.LayerNorm(mid_dim) + self.fc3 = nn.Linear(mid_dim, mid_dim) + self.gelu2 = nn.GELU() + self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens) + + # set the weights to 0 + self.fc3.weight.data.zero_() + self.fc4.weight.data.zero_() 
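+ # With the fc2, fc3 and fc4 weights zeroed, the forward pass initially reduces to
+ # fc4's bias, i.e. the same vector for every input image, so image-specific
+ # conditioning is only introduced gradually as training updates these layers.
+ # Overall, Embedder projects a CLIP vision embedding to num_output_tokens
+ # vectors of size output_dim, which ClipVisionAdapter writes into the newly
+ # added placeholder tokens of the text encoder(s).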
+ + + # self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim)) + # self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim)) + + def forward(self, x): + if len(x.shape) == 2: + x = x.unsqueeze(1) + x = self.layer_norm(x) + x = self.fc1(x) + x = self.gelu(x) + x = self.fc2(x) + x = self.layer_norm2(x) + x = self.fc3(x) + x = self.gelu2(x) + x = self.fc4(x) + + x = x.view(-1, self.num_output_tokens, self.output_dim) + + return x + + +class ClipVisionAdapter(torch.nn.Module): + def __init__(self, sd: 'StableDiffusion', adapter_config: AdapterConfig): + super().__init__() + self.config = adapter_config + self.trigger = adapter_config.trigger + self.trigger_class_name = adapter_config.trigger_class_name + self.sd_ref: weakref.ref = weakref.ref(sd) + # embedding stuff + self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder] + self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer] + placeholder_tokens = [self.trigger] + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, self.config.num_tokens): + additional_tokens.append(f"{self.trigger}_{i}") + placeholder_tokens += additional_tokens + + # handle dual tokenizer + self.tokenizer_list = self.sd_ref().tokenizer if isinstance(self.sd_ref().tokenizer, list) else [ + self.sd_ref().tokenizer] + self.text_encoder_list = self.sd_ref().text_encoder if isinstance(self.sd_ref().text_encoder, list) else [ + self.sd_ref().text_encoder] + + self.placeholder_token_ids = [] + self.embedding_tokens = [] + + print(f"Adding {placeholder_tokens} tokens to tokenizer") + print(f"Adding {self.config.num_tokens} tokens to tokenizer") + + + for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list): + num_added_tokens = tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != self.config.num_tokens: + raise ValueError( + f"The tokenizer already contains the token {self.trigger}. Please pass a different" + f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}" + ) + + # Convert the initializer_token, placeholder_token to ids + init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False) + # if length of token ids is more than number of orm embedding tokens fill with * + if len(init_token_ids) > self.config.num_tokens: + init_token_ids = init_token_ids[:self.config.num_tokens] + elif len(init_token_ids) < self.config.num_tokens: + pad_token_id = tokenizer.encode(["*"], add_special_tokens=False) + init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids)) + + placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False) + self.placeholder_token_ids.append(placeholder_token_ids) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids): + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # replace "[name] with this. on training. 
This is automatically generated in pipeline on inference + self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))) + + # backup text encoder embeddings + self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list] + + try: + self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = CLIPImageProcessor() + self.device = self.sd_ref().unet.device + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( + self.config.image_encoder_path, + ignore_mismatched_sizes=True + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + if self.config.train_image_encoder: + self.image_encoder.train() + else: + self.image_encoder.eval() + + # max_seq_len = CLIP tokens + CLS token + image_encoder_state_dict = self.image_encoder.state_dict() + in_tokens = 257 + if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: + # clip + in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + if hasattr(self.image_encoder.config, 'hidden_sizes'): + embedding_dim = self.image_encoder.config.hidden_sizes[-1] + else: + embedding_dim = self.image_encoder.config.target_hidden_size + + if self.config.clip_layer == 'image_embeds': + in_tokens = 1 + embedding_dim = self.image_encoder.config.projection_dim + + self.embedder = Embedder( + num_output_tokens=self.config.num_tokens, + num_input_tokens=in_tokens, + input_dim=embedding_dim, + output_dim=self.sd_ref().unet.config['cross_attention_dim'], + mid_dim=embedding_dim * self.config.num_tokens, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + + self.embedder.train() + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + state_dict = { + 'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + } + if self.config.train_image_encoder: + state_dict['image_encoder'] = self.image_encoder.state_dict( + *args, destination=destination, prefix=prefix, + keep_vars=keep_vars) + + return state_dict + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + self.embedder.load_state_dict(state_dict["embedder"], strict=strict) + if self.config.train_image_encoder and 'image_encoder' in state_dict: + self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict) + + def parameters(self, *args, **kwargs): + yield from self.embedder.parameters(*args, **kwargs) + + def named_parameters(self, *args, **kwargs): + yield from self.embedder.named_parameters(*args, **kwargs) + + def get_clip_image_embeds_from_tensors( + self, tensors_0_1: torch.Tensor, drop=False, + is_training=False, + has_been_preprocessed=False + ) -> torch.Tensor: + with torch.no_grad(): + if not has_been_preprocessed: + # tensors should be 0-1 + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. 
Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + # unconditional + if drop: + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + clip_image = self.clip_image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + else: + if drop: + # scale the noise down + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + + else: + clip_image = tensors_0_1 + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + with torch.set_grad_enabled(is_training): + if is_training: + self.image_encoder.train() + else: + self.image_encoder.eval() + clip_output = self.image_encoder(clip_image, output_hidden_states=True) + + if self.config.clip_layer == 'penultimate_hidden_states': + # they skip last layer for ip+ + # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26 + clip_image_embeds = clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + clip_image_embeds = clip_output.hidden_states[-1] + else: + clip_image_embeds = clip_output.image_embeds + return clip_image_embeds + + import torch + + def set_vec(self, new_vector, text_encoder_idx=0): + # Get the embedding layer + embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings() + + # Indices to replace in the embeddings + indices_to_replace = self.placeholder_token_ids[text_encoder_idx] + + # Replace the specified embeddings with new_vector + for idx in indices_to_replace: + vector_idx = idx - indices_to_replace[0] + embedding_layer.weight[idx] = new_vector[vector_idx] + + # adds it to the tokenizer + def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + if clip_image_embeds.ndim == 2: + # expand the token dimension + clip_image_embeds = clip_image_embeds.unsqueeze(1) + image_prompt_embeds = self.embedder(clip_image_embeds) + # todo add support for multiple batch sizes + if image_prompt_embeds.shape[0] != 1: + raise ValueError("Batch size must be 1 for embedder for now") + + # output on sd1.5 is bs, num_tokens, 768 + if len(self.text_encoder_list) == 1: + # add it to the text encoder + self.set_vec(image_prompt_embeds[0], text_encoder_idx=0) + elif len(self.text_encoder_list) == 2: + if self.text_encoder_list[0].config.target_hidden_size + self.text_encoder_list[1].config.target_hidden_size != \ + 
image_prompt_embeds.shape[2]: + raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes") + # sdxl variants + # image_prompt_embeds = 2048 + # te1 = 768 + # te2 = 1280 + te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.target_hidden_size] + te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.target_hidden_size:] + self.set_vec(te1_embeds[0], text_encoder_idx=0) + self.set_vec(te2_embeds[0], text_encoder_idx=1) + else: + + raise ValueError("Unsupported number of text encoders") + # just a place to put a breakpoint + pass + + def restore_embeddings(self): + # Let's make sure we don't update any embedding weights besides the newly added token + for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip( + self.text_encoder_list, + self.tokenizer_list, + self.orig_embeds_params, + self.placeholder_token_ids + ): + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[ + min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False + with torch.no_grad(): + text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds[index_no_updates] + # detach it all + text_encoder.get_input_embeddings().weight.detach_() + + def enable_gradient_checkpointing(self): + self.image_encoder.gradient_checkpointing = True + + def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True): + output_prompt = prompt + embedding_tokens = self.embedding_tokens[0] # shoudl be the same + default_replacements = ["[name]", "[trigger]"] + + replace_with = embedding_tokens if expand_token else self.trigger + if to_replace_list is None: + to_replace_list = default_replacements + else: + to_replace_list += default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances == 0 and add_if_not_present: + # add it to the beginning of the prompt + output_prompt = replace_with + " " + output_prompt + + if num_instances > 1: + print( + f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.") + + return output_prompt + + # reverses injection with class name. useful for normalizations + def inject_trigger_class_name_into_prompt(self, prompt): + output_prompt = prompt + embedding_tokens = self.embedding_tokens[0] # shoudl be the same + + default_replacements = ["[name]", "[trigger]", embedding_tokens, self.trigger] + + replace_with = self.config.trigger_class_name + to_replace_list = default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances > 1: + print( + f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. 
This may cause issues.") + + return output_prompt diff --git a/toolkit/config.py b/toolkit/config.py new file mode 100644 index 0000000000000000000000000000000000000000..52de47b836540c319e3ca8faa479312619139769 --- /dev/null +++ b/toolkit/config.py @@ -0,0 +1,110 @@ +import os +import json +from typing import Union + +import oyaml as yaml +import re +from collections import OrderedDict + +from toolkit.paths import TOOLKIT_ROOT + +possible_extensions = ['.json', '.jsonc', '.yaml', '.yml'] + + +def get_cwd_abs_path(path): + if not os.path.isabs(path): + path = os.path.join(os.getcwd(), path) + return path + + +def replace_env_vars_in_string(s: str) -> str: + """ + Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable. + If the environment variable is not set, raise an error. + """ + + def replacer(match): + var_name = match.group(1) + value = os.environ.get(var_name) + + if value is None: + raise ValueError(f"Environment variable {var_name} not set. Please ensure it's defined before proceeding.") + + return value + + return re.sub(r'\$\{([^}]+)\}', replacer, s) + + +def preprocess_config(config: OrderedDict, name: str = None): + if "job" not in config: + raise ValueError("config file must have a job key") + if "config" not in config: + raise ValueError("config file must have a config section") + if "name" not in config["config"] and name is None: + raise ValueError("config file must have a config.name key") + # we need to replace tags. For now just [name] + if name is None: + name = config["config"]["name"] + config_string = json.dumps(config) + config_string = config_string.replace("[name]", name) + config = json.loads(config_string, object_pairs_hook=OrderedDict) + return config + + +# Fixes issue where yaml doesnt load exponents correctly +fixed_loader = yaml.SafeLoader +fixed_loader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + + +def get_config( + config_file_path_or_dict: Union[str, dict, OrderedDict], + name=None +): + # if we got a dict, process it and return it + if isinstance(config_file_path_or_dict, dict) or isinstance(config_file_path_or_dict, OrderedDict): + config = config_file_path_or_dict + return preprocess_config(config, name) + + config_file_path = config_file_path_or_dict + + # first check if it is in the config folder + config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path) + # see if it is in the config folder with any of the possible extensions if it doesnt have one + real_config_path = None + if not os.path.exists(config_path): + for ext in possible_extensions: + if os.path.exists(config_path + ext): + real_config_path = config_path + ext + break + + # if we didn't find it there, check if it is a full path + if not real_config_path: + if os.path.exists(config_file_path): + real_config_path = config_file_path + elif os.path.exists(get_cwd_abs_path(config_file_path)): + real_config_path = get_cwd_abs_path(config_file_path) + + if not real_config_path: + raise ValueError(f"Could not find config file {config_file_path}") + + # if we found it, check if it is a json or yaml file + with open(real_config_path, 'r', encoding='utf-8') as f: + content = f.read() + content_with_env_replaced = replace_env_vars_in_string(content) + if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'): + config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict) + elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'): + config = yaml.load(content_with_env_replaced, Loader=fixed_loader) + else: + raise ValueError(f"Config file {config_file_path} must be a json or yaml file") + + return preprocess_config(config, name) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7215bf76e7fef5daee5c0000b6c1053608745d --- /dev/null +++ b/toolkit/config_modules.py @@ -0,0 +1,927 @@ +import os +import time +from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict +import random + +import torch + +from toolkit.prompt_utils import PromptEmbeds + +ImgExt = Literal['jpg', 'png', 'webp'] + +SaveFormat = Literal['safetensors', 'diffusers'] + +if TYPE_CHECKING: + from toolkit.guidance import GuidanceType + from toolkit.logging import EmptyLogger +else: + EmptyLogger = None + +class SaveConfig: + def __init__(self, **kwargs): + self.save_every: int = kwargs.get('save_every', 1000) + self.dtype: str = kwargs.get('dtype', 'float16') + self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5) + self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors') + if self.save_format not in ['safetensors', 'diffusers']: + raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}") + self.push_to_hub: bool = kwargs.get("push_to_hub", False) + self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None) + self.hf_private: Optional[str] = kwargs.get("hf_private", False) + +class LoggingConfig: + def __init__(self, **kwargs): + self.log_every: int = kwargs.get('log_every', 100) + self.verbose: bool = kwargs.get('verbose', False) + self.use_wandb: bool = kwargs.get('use_wandb', False) + self.project_name: str = kwargs.get('project_name', 'ai-toolkit') + self.run_name: str = 
kwargs.get('run_name', None) + + +class SampleConfig: + def __init__(self, **kwargs): + self.sampler: str = kwargs.get('sampler', 'ddpm') + self.sample_every: int = kwargs.get('sample_every', 100) + self.width: int = kwargs.get('width', 512) + self.height: int = kwargs.get('height', 512) + self.prompts: list[str] = kwargs.get('prompts', []) + self.neg = kwargs.get('neg', False) + self.seed = kwargs.get('seed', 0) + self.walk_seed = kwargs.get('walk_seed', False) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.network_multiplier = kwargs.get('network_multiplier', 1) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.ext: ImgExt = kwargs.get('format', 'jpg') + self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) + self.refiner_start_at = kwargs.get('refiner_start_at', + 0.5) # step to start using refiner on sample if it exists + self.extra_values = kwargs.get('extra_values', []) + + +class LormModuleSettingsConfig: + def __init__(self, **kwargs): + self.contains: str = kwargs.get('contains', '4nt$3') + self.extract_mode: str = kwargs.get('extract_mode', 'ratio') + # min num parameters to attach to + self.parameter_threshold: int = kwargs.get('parameter_threshold', 0) + self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25) + + +class LoRMConfig: + def __init__(self, **kwargs): + self.extract_mode: str = kwargs.get('extract_mode', 'ratio') + self.do_conv: bool = kwargs.get('do_conv', False) + self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25) + self.parameter_threshold: int = kwargs.get('parameter_threshold', 0) + module_settings = kwargs.get('module_settings', []) + default_module_settings = { + 'extract_mode': self.extract_mode, + 'extract_mode_param': self.extract_mode_param, + 'parameter_threshold': self.parameter_threshold, + } + module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings] + self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for + module_setting in module_settings] + + def get_config_for_module(self, block_name): + for setting in self.module_settings: + contain_pieces = setting.contains.split('|') + if all(contain_piece in block_name for contain_piece in contain_pieces): + return setting + # try replacing the . 
with _ + contain_pieces = setting.contains.replace('.', '_').split('|') + if all(contain_piece in block_name for contain_piece in contain_pieces): + return setting + # do default + return LormModuleSettingsConfig(**{ + 'extract_mode': self.extract_mode, + 'extract_mode_param': self.extract_mode_param, + 'parameter_threshold': self.parameter_threshold, + }) + + +NetworkType = Literal['lora', 'locon', 'lorm'] + + +class NetworkConfig: + def __init__(self, **kwargs): + self.type: NetworkType = kwargs.get('type', 'lora') + rank = kwargs.get('rank', None) + linear = kwargs.get('linear', None) + if rank is not None: + self.rank: int = rank # rank for backward compatibility + self.linear: int = rank + elif linear is not None: + self.rank: int = linear + self.linear: int = linear + self.conv: int = kwargs.get('conv', None) + self.alpha: float = kwargs.get('alpha', 1.0) + self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) + self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) + self.dropout: Union[float, None] = kwargs.get('dropout', None) + self.network_kwargs: dict = kwargs.get('network_kwargs', {}) + + self.lorm_config: Union[LoRMConfig, None] = None + lorm = kwargs.get('lorm', None) + if lorm is not None: + self.lorm_config: LoRMConfig = LoRMConfig(**lorm) + + if self.type == 'lorm': + # set linear/rank to arbitrary values so the modules still get created + self.linear = 4 + self.rank = 4 + if self.lorm_config.do_conv: + self.conv = 4 + + self.transformer_only = kwargs.get('transformer_only', True) + + +AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net'] + +CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state'] + + +class AdapterConfig: + def __init__(self, **kwargs): + self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net + self.in_channels: int = kwargs.get('in_channels', 3) + self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) + self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) + self.downscale_factor: int = kwargs.get('downscale_factor', 8) + self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter') + self.image_dir: str = kwargs.get('image_dir', None) + self.test_img_path: str = kwargs.get('test_img_path', None) + self.train: bool = kwargs.get('train', False) + self.image_encoder_path: str = kwargs.get('image_encoder_path', None) + self.name_or_path = kwargs.get('name_or_path', None) + + num_tokens = kwargs.get('num_tokens', None) + if num_tokens is None and self.type.startswith('ip'): + if self.type == 'ip+': + num_tokens = 16 + elif self.type == 'ip': + num_tokens = 4 + + self.num_tokens: int = num_tokens + self.train_image_encoder: bool = kwargs.get('train_image_encoder', False) + self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False) + if self.train_only_image_encoder: + self.train_image_encoder = True + self.train_only_image_encoder_positional_embedding: bool = kwargs.get( + 'train_only_image_encoder_positional_embedding', False) + self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe + self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512) + self.safe_channels: int = kwargs.get('safe_channels', 2048) + self.safe_tokens: int = kwargs.get('safe_tokens', 8) + self.quad_image: bool = kwargs.get('quad_image', False) + + # clip vision + self.trigger = kwargs.get('trigger', 'tri993r') + self.trigger_class_name = kwargs.get('trigger_class_name',
None) + + self.class_names = kwargs.get('class_names', []) + + self.clip_layer: CLIPLayer = kwargs.get('clip_layer', None) + if self.clip_layer is None: + if self.type.startswith('ip+'): + self.clip_layer = 'penultimate_hidden_states' + else: + self.clip_layer = 'last_hidden_state' + + # text encoder + self.text_encoder_path: str = kwargs.get('text_encoder_path', None) + self.text_encoder_arch: str = kwargs.get('text_encoder_arch', 'clip') # clip t5 + + self.train_scaler: bool = kwargs.get('train_scaler', False) + self.scaler_lr: Optional[float] = kwargs.get('scaler_lr', None) + + # trains with a scaler to easy channel bias but merges it in on save + self.merge_scaler: bool = kwargs.get('merge_scaler', False) + + # for ilora + self.head_dim: int = kwargs.get('head_dim', 1024) + self.num_heads: int = kwargs.get('num_heads', 1) + self.ilora_down: bool = kwargs.get('ilora_down', True) + self.ilora_mid: bool = kwargs.get('ilora_mid', True) + self.ilora_up: bool = kwargs.get('ilora_up', True) + + self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512) + self.pixtral_random_image_size: int = kwargs.get('pixtral_random_image_size', False) + + self.flux_only_double: bool = kwargs.get('flux_only_double', False) + + # train and use a conv layer to pool the embedding + self.conv_pooling: bool = kwargs.get('conv_pooling', False) + self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1) + self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None) + + +class EmbeddingConfig: + def __init__(self, **kwargs): + self.trigger = kwargs.get('trigger', 'custom_embedding') + self.tokens = kwargs.get('tokens', 4) + self.init_words = kwargs.get('init_words', '*') + self.save_format = kwargs.get('save_format', 'safetensors') + self.trigger_class_name = kwargs.get('trigger_class_name', None) # used for inverted masked prior + + +class DecoratorConfig: + def __init__(self, **kwargs): + self.num_tokens: str = kwargs.get('num_tokens', 4) + + +ContentOrStyleType = Literal['balanced', 'style', 'content'] +LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise'] + + +class TrainConfig: + def __init__(self, **kwargs): + self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') + self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') + self.content_or_style_reg: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') + self.steps: int = kwargs.get('steps', 1000) + self.lr = kwargs.get('lr', 1e-6) + self.unet_lr = kwargs.get('unet_lr', self.lr) + self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr) + self.refiner_lr = kwargs.get('refiner_lr', self.lr) + self.embedding_lr = kwargs.get('embedding_lr', self.lr) + self.adapter_lr = kwargs.get('adapter_lr', self.lr) + self.optimizer = kwargs.get('optimizer', 'adamw') + self.optimizer_params = kwargs.get('optimizer_params', {}) + self.lr_scheduler = kwargs.get('lr_scheduler', 'constant') + self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {}) + self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0) + self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000) + self.batch_size: int = kwargs.get('batch_size', 1) + self.orig_batch_size: int = self.batch_size + self.dtype: str = kwargs.get('dtype', 'fp32') + self.xformers = kwargs.get('xformers', False) + self.sdp = kwargs.get('sdp', False) + self.train_unet = kwargs.get('train_unet', True) + self.train_text_encoder = kwargs.get('train_text_encoder', False) + 
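Each of these `TrainConfig` fields is read straight off the parsed config dict via `kwargs.get(...)`, so the class is normally built by unpacking that dict. A minimal illustrative sketch follows; the key names come from the defaults defined in this class, while the values themselves are hypothetical and not taken from this PR:

```python
# Illustrative only: unpack a parsed "train" config section into TrainConfig.
# Key names mirror the kwargs.get(...) calls in this class; values are made up.
from toolkit.config_modules import TrainConfig

train_section = {
    "steps": 2000,
    "lr": 1e-4,
    "batch_size": 1,
    "optimizer": "adamw",
    "noise_scheduler": "ddpm",
    "dtype": "bf16",
    "gradient_checkpointing": True,
}
train_config = TrainConfig(**train_section)

# unet_lr / text_encoder_lr fall back to lr when not set explicitly
print(train_config.unet_lr, train_config.text_encoder_lr)  # 0.0001 0.0001
```

The mapping from config keys to attributes is intended to be one-to-one, with unspecified keys falling back to the defaults shown in this constructor.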
self.train_refiner = kwargs.get('train_refiner', True) + self.train_turbo = kwargs.get('train_turbo', False) + self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False) + self.min_snr_gamma = kwargs.get('min_snr_gamma', None) + self.snr_gamma = kwargs.get('snr_gamma', None) + # trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials + # this should balance the learning rate across all timesteps over time + self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False) + self.noise_offset = kwargs.get('noise_offset', 0.0) + self.skip_first_sample = kwargs.get('skip_first_sample', False) + self.force_first_sample = kwargs.get('force_first_sample', False) + self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True) + self.weight_jitter = kwargs.get('weight_jitter', 0.0) + self.merge_network_on_save = kwargs.get('merge_network_on_save', False) + self.max_grad_norm = kwargs.get('max_grad_norm', 1.0) + self.start_step = kwargs.get('start_step', None) + self.free_u = kwargs.get('free_u', False) + self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None) + self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net + self.noise_multiplier = kwargs.get('noise_multiplier', 1.0) + self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0) + self.img_multiplier = kwargs.get('img_multiplier', 1.0) + self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0) + self.latent_multiplier = kwargs.get('latent_multiplier', 1.0) + self.negative_prompt = kwargs.get('negative_prompt', None) + self.max_negative_prompts = kwargs.get('max_negative_prompts', 1) + # multiplier applied to loos on regularization images + self.reg_weight = kwargs.get('reg_weight', 1.0) + self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000) + self.random_noise_shift = kwargs.get('random_noise_shift', 0.0) + # automatically adapte the vae scaling based on the image norm + self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False) + + # dropout that happens before encoding. It functions independently per text encoder + self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0) + + # match the norm of the noise before computing loss. This will help the model maintain its + # current understandin of the brightness of images. + + self.match_noise_norm = kwargs.get('match_noise_norm', False) + + # set to -1 to accumulate gradients for entire epoch + # warning, only do this with a small dataset or you will run out of memory + # This is legacy but left in for backwards compatibility + self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1) + + # this will do proper gradient accumulation where you will not see a step until the end of the accumulation + # the method above will show a step every accumulation + self.gradient_accumulation = kwargs.get('gradient_accumulation', 1) + if self.gradient_accumulation > 1: + if self.gradient_accumulation_steps != 1: + raise ValueError("gradient_accumulation and gradient_accumulation_steps are mutually exclusive") + + # short long captions will double your batch size. This only works when a dataset is + # prepared with a json caption file that has both short and long captions in it. It will + # Double up every image and run it through with both short and long captions. 
The idea + # is that the network will learn how to generate good images with both short and long captions + self.short_and_long_captions = kwargs.get('short_and_long_captions', False) + # if above is NOT true, this will make it so the long caption goes to te2 and the short caption goes to te1 for sdxl only + self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False) + + # basically gradient accumulation but we run just 1 item through the network + # and accumulate gradients. This can be used as basic gradient accumulation but is very helpful + # for training tricks that increase batch size but need a single gradient step + self.single_item_batching = kwargs.get('single_item_batching', False) + + match_adapter_assist = kwargs.get('match_adapter_assist', False) + self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) + self.loss_target: LossTarget = kwargs.get('loss_target', + 'noise') # noise, source, unaugmented, differential_noise + + # When a mask is passed in a dataset, and this is true, + # we will predict noise without the LoRA network and use the prediction as a target for + # the unmasked region. It is essentially unmasked regularization + self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False) + self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5) + + # legacy + if match_adapter_assist and self.match_adapter_chance == 0.0: + self.match_adapter_chance = 1.0 + + # standardize inputs to the mean and std of the model knowledge + self.standardize_images = kwargs.get('standardize_images', False) + self.standardize_latents = kwargs.get('standardize_latents', False) + + if self.train_turbo and not self.noise_scheduler.startswith("euler"): + raise ValueError(f"train_turbo is only supported with euler and euler_a noise schedulers") + + self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False) + self.do_cfg = kwargs.get('do_cfg', False) + self.do_random_cfg = kwargs.get('do_random_cfg', False) + self.cfg_scale = kwargs.get('cfg_scale', 1.0) + self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale) + self.cfg_rescale = kwargs.get('cfg_rescale', None) + if self.cfg_rescale is None: + self.cfg_rescale = self.cfg_scale + + # applies the inverse of the prediction mean and std to the target to correct + # for norm drift + self.correct_pred_norm = kwargs.get('correct_pred_norm', False) + self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0) + + self.loss_type = kwargs.get('loss_type', 'mse') + + # scale the prediction by this.
Increase for more detail, decrease for less + self.pred_scaler = kwargs.get('pred_scaler', 1.0) + + # repeats the prompt a few times to saturate the encoder + self.prompt_saturation_chance = kwargs.get('prompt_saturation_chance', 0.0) + + # applies negative loss on the prior to encourage network to diverge from it + self.do_prior_divergence = kwargs.get('do_prior_divergence', False) + + ema_config: Union[Dict, None] = kwargs.get('ema_config', None) + if ema_config is not None: + ema_config['use_ema'] = True + print(f"Using EMA") + else: + ema_config = {'use_ema': False} + + self.ema_config: EMAConfig = EMAConfig(**ema_config) + + # adds an additional loss to the network to encourage it output a normalized standard deviation + self.target_norm_std = kwargs.get('target_norm_std', None) + self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0) + self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear + self.linear_timesteps = kwargs.get('linear_timesteps', False) + self.linear_timesteps2 = kwargs.get('linear_timesteps2', False) + self.disable_sampling = kwargs.get('disable_sampling', False) + + # will cache a blank prompt or the trigger word, and unload the text encoder to cpu + # will make training faster and use less vram + self.unload_text_encoder = kwargs.get('unload_text_encoder', False) + # for swapping which parameters are trained during training + self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False) + # 0.1 is 10% of the parameters active at a time lower is less vram, higher is more + self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1) + # bypass the guidance embedding for training. For open flux with guidance embedding + self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False) + + +class ModelConfig: + def __init__(self, **kwargs): + self.name_or_path: str = kwargs.get('name_or_path', None) + # name or path is updated on fine tuning. 
Keep a copy of the original + self.name_or_path_original: str = self.name_or_path + self.is_v2: bool = kwargs.get('is_v2', False) + self.is_xl: bool = kwargs.get('is_xl', False) + self.is_pixart: bool = kwargs.get('is_pixart', False) + self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False) + self.is_auraflow: bool = kwargs.get('is_auraflow', False) + self.is_v3: bool = kwargs.get('is_v3', False) + self.is_flux: bool = kwargs.get('is_flux', False) + if self.is_pixart_sigma: + self.is_pixart = True + self.use_flux_cfg = kwargs.get('use_flux_cfg', False) + self.is_ssd: bool = kwargs.get('is_ssd', False) + self.is_vega: bool = kwargs.get('is_vega', False) + self.is_v_pred: bool = kwargs.get('is_v_pred', False) + self.dtype: str = kwargs.get('dtype', 'float16') + self.vae_path = kwargs.get('vae_path', None) + self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None) + self._original_refiner_name_or_path = self.refiner_name_or_path + self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) + self.lora_path = kwargs.get('lora_path', None) + # mainly for decompression loras for distilled models + self.assistant_lora_path = kwargs.get('assistant_lora_path', None) + self.inference_lora_path = kwargs.get('inference_lora_path', None) + self.latent_space_version = kwargs.get('latent_space_version', None) + + # only for SDXL models for now + self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True) + self.use_text_encoder_2: bool = kwargs.get('use_text_encoder_2', True) + + self.experimental_xl: bool = kwargs.get('experimental_xl', False) + + if self.name_or_path is None: + raise ValueError('name_or_path must be specified') + + if self.is_ssd: + # set is_xl to True since it is mostly the same architecture + self.is_xl = True + + if self.is_vega: + self.is_xl = True + + # for text encoder quant.
Only works with pixart currently + self.text_encoder_bits = kwargs.get('text_encoder_bits', 16) # 16, 8, 4 + self.unet_path = kwargs.get("unet_path", None) + self.unet_sample_size = kwargs.get("unet_sample_size", None) + self.vae_device = kwargs.get("vae_device", None) + self.vae_dtype = kwargs.get("vae_dtype", self.dtype) + self.te_device = kwargs.get("te_device", None) + self.te_dtype = kwargs.get("te_dtype", self.dtype) + + # only for flux for now + self.quantize = kwargs.get("quantize", False) + self.low_vram = kwargs.get("low_vram", False) + self.attn_masking = kwargs.get("attn_masking", False) + if self.attn_masking and not self.is_flux: + raise ValueError("attn_masking is only supported with flux models currently") + # for targeting a specific layers + self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None) + self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None) + self.quantize_kwargs = kwargs.get("quantize_kwargs", {}) + + if self.ignore_if_contains is not None or self.only_if_contains is not None: + if not self.is_flux: + raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently") + + +class EMAConfig: + def __init__(self, **kwargs): + self.use_ema: bool = kwargs.get('use_ema', False) + self.ema_decay: float = kwargs.get('ema_decay', 0.999) + # feeds back the decay difference into the parameter + self.use_feedback: bool = kwargs.get('use_feedback', False) + + # every update, the params are multiplied by this amount + # only use for things without a bias like lora + # similar to a decay in an optimizer but the opposite + self.param_multiplier: float = kwargs.get('param_multiplier', 1.0) + + +class ReferenceDatasetConfig: + def __init__(self, **kwargs): + # can pass with a side by side pait or a folder with pos and neg folder + self.pair_folder: str = kwargs.get('pair_folder', None) + self.pos_folder: str = kwargs.get('pos_folder', None) + self.neg_folder: str = kwargs.get('neg_folder', None) + + self.network_weight: float = float(kwargs.get('network_weight', 1.0)) + self.pos_weight: float = float(kwargs.get('pos_weight', self.network_weight)) + self.neg_weight: float = float(kwargs.get('neg_weight', self.network_weight)) + # make sure they are all absolute values no negatives + self.pos_weight = abs(self.pos_weight) + self.neg_weight = abs(self.neg_weight) + + self.target_class: str = kwargs.get('target_class', '') + self.size: int = kwargs.get('size', 512) + + +class SliderTargetConfig: + def __init__(self, **kwargs): + self.target_class: str = kwargs.get('target_class', '') + self.positive: str = kwargs.get('positive', '') + self.negative: str = kwargs.get('negative', '') + self.multiplier: float = kwargs.get('multiplier', 1.0) + self.weight: float = kwargs.get('weight', 1.0) + self.shuffle: bool = kwargs.get('shuffle', False) + + +class GuidanceConfig: + def __init__(self, **kwargs): + self.target_class: str = kwargs.get('target_class', '') + self.guidance_scale: float = kwargs.get('guidance_scale', 1.0) + self.positive_prompt: str = kwargs.get('positive_prompt', '') + self.negative_prompt: str = kwargs.get('negative_prompt', '') + + +class SliderConfigAnchors: + def __init__(self, **kwargs): + self.prompt = kwargs.get('prompt', '') + self.neg_prompt = kwargs.get('neg_prompt', '') + self.multiplier = kwargs.get('multiplier', 1.0) + + +class SliderConfig: + def __init__(self, **kwargs): + targets = kwargs.get('targets', []) + anchors = kwargs.get('anchors', []) + anchors = 
[SliderConfigAnchors(**anchor) for anchor in anchors] + self.anchors: List[SliderConfigAnchors] = anchors + self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]]) + self.prompt_file: str = kwargs.get('prompt_file', None) + self.prompt_tensors: str = kwargs.get('prompt_tensors', None) + self.batch_full_slide: bool = kwargs.get('batch_full_slide', True) + self.use_adapter: bool = kwargs.get('use_adapter', None) # depth + self.adapter_img_dir = kwargs.get('adapter_img_dir', None) + self.low_ram = kwargs.get('low_ram', False) + + # expand targets if shuffling + from toolkit.prompt_utils import get_slider_target_permutations + self.targets: List[SliderTargetConfig] = [] + targets = [SliderTargetConfig(**target) for target in targets] + # do permutations if shuffle is true + print(f"Building slider targets") + for target in targets: + if target.shuffle: + target_permutations = get_slider_target_permutations(target, max_permutations=8) + self.targets = self.targets + target_permutations + else: + self.targets.append(target) + print(f"Built {len(self.targets)} slider targets (with permutations)") + + +class DatasetConfig: + """ + Dataset config for sd-datasets + + """ + + def __init__(self, **kwargs): + self.type = kwargs.get('type', 'image') # sd, slider, reference + # will be legacy + self.folder_path: str = kwargs.get('folder_path', None) + # can be json or folder path + self.dataset_path: str = kwargs.get('dataset_path', None) + + self.default_caption: str = kwargs.get('default_caption', None) + # trigger word for just this dataset + self.trigger_word: str = kwargs.get('trigger_word', None) + random_triggers = kwargs.get('random_triggers', []) + # if they are a string, load them from a file + if isinstance(random_triggers, str) and os.path.exists(random_triggers): + with open(random_triggers, 'r') as f: + random_triggers = f.read().splitlines() + # remove empty lines + random_triggers = [line for line in random_triggers if line.strip() != ''] + self.random_triggers: List[str] = random_triggers + self.random_triggers_max: int = kwargs.get('random_triggers_max', 1) + self.caption_ext: str = kwargs.get('caption_ext', None) + self.random_scale: bool = kwargs.get('random_scale', False) + self.random_crop: bool = kwargs.get('random_crop', False) + self.resolution: int = kwargs.get('resolution', 512) + self.scale: float = kwargs.get('scale', 1.0) + self.buckets: bool = kwargs.get('buckets', True) + self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64) + self.is_reg: bool = kwargs.get('is_reg', False) + self.network_weight: float = float(kwargs.get('network_weight', 1.0)) + self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0)) + self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False) + self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0)) + self.keep_tokens: int = kwargs.get('keep_tokens', 0) # #of first tokens to always keep unless caption dropped + self.flip_x: bool = kwargs.get('flip_x', False) + self.flip_y: bool = kwargs.get('flip_y', False) + self.augments: List[str] = kwargs.get('augments', []) + self.control_path: str = kwargs.get('control_path', None) # depth maps, etc + # instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters) + self.full_size_control_images: bool = kwargs.get('full_size_control_images', False) + self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask + self.mask_path: str = 
kwargs.get('mask_path', + None) # focus mask (black and white. White has higher loss than black) + self.unconditional_path: str = kwargs.get('unconditional_path', + None) # path where matching unconditional images are located + self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask + self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1 + self.poi: Union[str, None] = kwargs.get('poi', + None) # if one is set and in json data, will be used as auto crop scale point of interes + self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset + # cache latents will store them in memory + self.cache_latents: bool = kwargs.get('cache_latents', False) + # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory + self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False) + self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False) + + self.standardize_images: bool = kwargs.get('standardize_images', False) + + # https://albumentations.ai/docs/api_reference/augmentations/transforms + # augmentations are returned as a separate image and cannot currently be cached + self.augmentations: List[dict] = kwargs.get('augmentations', None) + self.shuffle_augmentations: bool = kwargs.get('shuffle_augmentations', False) + + has_augmentations = self.augmentations is not None and len(self.augmentations) > 0 + + if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk): + print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False") + self.cache_latents = False + self.cache_latents_to_disk = False + + # legacy compatability + legacy_caption_type = kwargs.get('caption_type', None) + if legacy_caption_type: + self.caption_ext = legacy_caption_type + self.caption_type = self.caption_ext + self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted') + + # ip adapter / reference dataset + self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc + # get the clip image randomly from the same folder as the image. Useful for folder grouped pairs. + self.clip_image_from_same_folder: bool = kwargs.get('clip_image_from_same_folder', False) + self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None) + self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False) + self.replacements: List[str] = kwargs.get('replacements', []) + self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0) + + self.num_workers: int = kwargs.get('num_workers', 2) + self.prefetch_factor: int = kwargs.get('prefetch_factor', 2) + self.extra_values: List[float] = kwargs.get('extra_values', []) + self.square_crop: bool = kwargs.get('square_crop', False) + # apply same augmentations to control images. 
Usually want this true unless special case + self.replay_transforms: bool = kwargs.get('replay_transforms', True) + + +def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: + """ + This just splits up the datasets by resolutions so you dont have to do it manually + :param raw_config: + :return: + """ + # split up datasets by resolutions + new_config = [] + for dataset in raw_config: + resolution = dataset.get('resolution', 512) + if isinstance(resolution, list): + resolution_list = resolution + else: + resolution_list = [resolution] + for res in resolution_list: + dataset_copy = dataset.copy() + dataset_copy['resolution'] = res + new_config.append(dataset_copy) + return new_config + + +class GenerateImageConfig: + def __init__( + self, + prompt: str = '', + prompt_2: Optional[str] = None, + width: int = 512, + height: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str = '', + negative_prompt_2: Optional[str] = None, + seed: int = -1, + network_multiplier: float = 1.0, + guidance_rescale: float = 0.0, + # the tag [time] will be replaced with milliseconds since epoch + output_path: str = None, # full image path + output_folder: str = None, # folder to save image in if output_path is not specified + output_ext: str = ImgExt, # extension to save image as if output_path is not specified + output_tail: str = '', # tail to add to output filename + add_prompt_file: bool = False, # add a prompt file with generated image + adapter_image_path: str = None, # path to adapter image + adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning + latents: Union[torch.Tensor | None] = None, # input latent to start with, + extra_kwargs: dict = None, # extra data to save with prompt file + refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 
1.0 is the end + extra_values: List[float] = None, # extra values to save with prompt file + logger: Optional[EmptyLogger] = None, + ): + self.width: int = width + self.height: int = height + self.num_inference_steps: int = num_inference_steps + self.guidance_scale: float = guidance_scale + self.guidance_rescale: float = guidance_rescale + self.prompt: str = prompt + self.prompt_2: str = prompt_2 + self.negative_prompt: str = negative_prompt + self.negative_prompt_2: str = negative_prompt_2 + self.latents: Union[torch.Tensor | None] = latents + + self.output_path: str = output_path + self.seed: int = seed + if self.seed == -1: + # generate random one + self.seed = random.randint(0, 2 ** 32 - 1) + self.network_multiplier: float = network_multiplier + self.output_folder: str = output_folder + self.output_ext: str = output_ext + self.add_prompt_file: bool = add_prompt_file + self.output_tail: str = output_tail + self.gen_time: int = int(time.time() * 1000) + self.adapter_image_path: str = adapter_image_path + self.adapter_conditioning_scale: float = adapter_conditioning_scale + self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {} + self.refiner_start_at = refiner_start_at + self.extra_values = extra_values if extra_values is not None else [] + + # prompt string will override any settings above + self._process_prompt_string() + + # handle dual text encoder prompts if nothing passed + if negative_prompt_2 is None: + self.negative_prompt_2 = negative_prompt + + if prompt_2 is None: + self.prompt_2 = self.prompt + + # parse prompt paths + if self.output_path is None and self.output_folder is None: + raise ValueError('output_path or output_folder must be specified') + elif self.output_path is not None: + self.output_folder = os.path.dirname(self.output_path) + self.output_ext = os.path.splitext(self.output_path)[1][1:] + self.output_filename_no_ext = os.path.splitext(os.path.basename(self.output_path))[0] + + else: + self.output_filename_no_ext = '[time]_[count]' + if len(self.output_tail) > 0: + self.output_filename_no_ext += '_' + self.output_tail + self.output_path = os.path.join(self.output_folder, self.output_filename_no_ext + '.' + self.output_ext) + + # adjust height + self.height = max(64, self.height - self.height % 8) # round to divisible by 8 + self.width = max(64, self.width - self.width % 8) # round to divisible by 8 + + self.logger = logger + + def set_gen_time(self, gen_time: int = None): + if gen_time is not None: + self.gen_time = gen_time + else: + self.gen_time = int(time.time() * 1000) + + def _get_path_no_ext(self, count: int = 0, max_count=0): + # zero pad count + count_str = str(count).zfill(len(str(max_count))) + # replace [time] with gen time + filename = self.output_filename_no_ext.replace('[time]', str(self.gen_time)) + # replace [count] with count + filename = filename.replace('[count]', count_str) + return filename + + def get_image_path(self, count: int = 0, max_count=0): + filename = self._get_path_no_ext(count, max_count) + ext = self.output_ext + # if it does not start with a dot add one + if ext[0] != '.': + ext = '.' 
+ ext + filename += ext + # join with folder + return os.path.join(self.output_folder, filename) + + def get_prompt_path(self, count: int = 0, max_count=0): + filename = self._get_path_no_ext(count, max_count) + filename += '.txt' + # join with folder + return os.path.join(self.output_folder, filename) + + def save_image(self, image, count: int = 0, max_count=0): + # make parent dirs + os.makedirs(self.output_folder, exist_ok=True) + self.set_gen_time() + # TODO save image gen header info for A1111 and us, our seeds probably wont match + image.save(self.get_image_path(count, max_count)) + # do prompt file + if self.add_prompt_file: + self.save_prompt_file(count, max_count) + + def save_prompt_file(self, count: int = 0, max_count=0): + # save prompt file + with open(self.get_prompt_path(count, max_count), 'w') as f: + prompt = self.prompt + if self.prompt_2 is not None: + prompt += ' --p2 ' + self.prompt_2 + if self.negative_prompt is not None: + prompt += ' --n ' + self.negative_prompt + if self.negative_prompt_2 is not None: + prompt += ' --n2 ' + self.negative_prompt_2 + prompt += ' --w ' + str(self.width) + prompt += ' --h ' + str(self.height) + prompt += ' --seed ' + str(self.seed) + prompt += ' --cfg ' + str(self.guidance_scale) + prompt += ' --steps ' + str(self.num_inference_steps) + prompt += ' --m ' + str(self.network_multiplier) + prompt += ' --gr ' + str(self.guidance_rescale) + + # get gen info + f.write(self.prompt) + + def _process_prompt_string(self): + # we will try to support all sd-scripts where we can + + # FROM SD-SCRIPTS + # --n Treat everything until the next option as a negative prompt. + # --w Specify the width of the generated image. + # --h Specify the height of the generated image. + # --d Specify the seed for the generated image. + # --l Specify the CFG scale for the generated image. + # --s Specify the number of steps during generation. + + # OURS and some QOL additions + # --m Specify the network multiplier for the generated image. 
+ # --p2 Prompt for the second text encoder (SDXL only) + # --n2 Negative prompt for the second text encoder (SDXL only) + # --gr Specify the guidance rescale for the generated image (SDXL only) + + # --seed Specify the seed for the generated image same as --d + # --cfg Specify the CFG scale for the generated image same as --l + # --steps Specify the number of steps during generation same as --s + # --network_multiplier Specify the network multiplier for the generated image same as --m + + # process prompt string and update values if it has some + if self.prompt is not None and len(self.prompt) > 0: + # process prompt string + prompt = self.prompt + prompt = prompt.strip() + p_split = prompt.split('--') + self.prompt = p_split[0].strip() + + if len(p_split) > 1: + for split in p_split[1:]: + # allows multi char flags + flag = split.split(' ')[0].strip() + content = split[len(flag):].strip() + if flag == 'p2': + self.prompt_2 = content + elif flag == 'n': + self.negative_prompt = content + elif flag == 'n2': + self.negative_prompt_2 = content + elif flag == 'w': + self.width = int(content) + elif flag == 'h': + self.height = int(content) + elif flag == 'd': + self.seed = int(content) + elif flag == 'seed': + self.seed = int(content) + elif flag == 'l': + self.guidance_scale = float(content) + elif flag == 'cfg': + self.guidance_scale = float(content) + elif flag == 's': + self.num_inference_steps = int(content) + elif flag == 'steps': + self.num_inference_steps = int(content) + elif flag == 'm': + self.network_multiplier = float(content) + elif flag == 'network_multiplier': + self.network_multiplier = float(content) + elif flag == 'gr': + self.guidance_rescale = float(content) + elif flag == 'a': + self.adapter_conditioning_scale = float(content) + elif flag == 'ref': + self.refiner_start_at = float(content) + elif flag == 'ev': + # split by comma + self.extra_values = [float(val) for val in content.split(',')] + elif flag == 'extra_values': + # split by comma + self.extra_values = [float(val) for val in content.split(',')] + + def post_process_embeddings( + self, + conditional_prompt_embeds: PromptEmbeds, + unconditional_prompt_embeds: Optional[PromptEmbeds] = None, + ): + # this is called after prompt embeds are encoded. We can override them in the future here + pass + + def log_image(self, image, count: int = 0, max_count=0): + if self.logger is None: + return + + self.logger.log_image(image, count, self.prompt) + + +def validate_configs( + train_config: TrainConfig, + model_config: ModelConfig, + save_config: SaveConfig, +): + if model_config.is_flux: + if save_config.save_format != 'diffusers': + # make it diffusers + save_config.save_format = 'diffusers' + if model_config.use_flux_cfg: + # bypass the embedding + train_config.bypass_guidance_embedding = True diff --git a/toolkit/cuda_malloc.py b/toolkit/cuda_malloc.py new file mode 100644 index 0000000000000000000000000000000000000000..239b9666a83ea3f3838737b725902c6590ea19bc --- /dev/null +++ b/toolkit/cuda_malloc.py @@ -0,0 +1,93 @@ +# ref comfy ui +import os +import importlib.util + + +# Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. 
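Because the allocator backend has to be set before torch initializes CUDA (per the comment above), this module is meant to be imported ahead of anything that pulls in torch. A minimal sketch of that assumed usage; only the `toolkit.cuda_malloc` module path comes from this PR, the rest is illustrative:

```python
# Importing toolkit.cuda_malloc runs its module-level check and, when torch >= 2.0 is
# installed and no blacklisted GPU is detected, appends "backend:cudaMallocAsync" to
# PYTORCH_CUDA_ALLOC_CONF before torch itself is imported.
import os

from toolkit import cuda_malloc  # noqa: F401  (import is for its side effect)

import torch  # torch reads the allocator setting on first CUDA initialization

print(os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))  # may be None if async malloc was not enabled
print(torch.cuda.is_available())
```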
+def get_gpu_names(): + if os.name == 'nt': + import ctypes + + # Define necessary C structures and types + class DISPLAY_DEVICEA(ctypes.Structure): + _fields_ = [ + ('cb', ctypes.c_ulong), + ('DeviceName', ctypes.c_char * 32), + ('DeviceString', ctypes.c_char * 128), + ('StateFlags', ctypes.c_ulong), + ('DeviceID', ctypes.c_char * 128), + ('DeviceKey', ctypes.c_char * 128) + ] + + # Load user32.dll + user32 = ctypes.windll.user32 + + # Call EnumDisplayDevicesA + def enum_display_devices(): + device_info = DISPLAY_DEVICEA() + device_info.cb = ctypes.sizeof(device_info) + device_index = 0 + gpu_names = set() + + while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0): + device_index += 1 + gpu_names.add(device_info.DeviceString.decode('utf-8')) + return gpu_names + + return enum_display_devices() + else: + return set() + + +blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", + "GeForce 945M", + "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", + "Quadro K620", + "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000", + "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", + "Quadro M5500", "Quadro M6000", + "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M", + "GeForce GTX 1650", "GeForce GTX 1630" + } + + +def cuda_malloc_supported(): + try: + names = get_gpu_names() + except: + names = set() + for x in names: + if "NVIDIA" in x: + for b in blacklist: + if b in x: + return False + return True + + +cuda_malloc = False + +if not cuda_malloc: + try: + version = "" + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ + if int(version[0]) >= 2: # enable by default for torch version 2.0 and up + cuda_malloc = cuda_malloc_supported() + except: + pass + +if cuda_malloc: + env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) + if env_var is None: + env_var = "backend:cudaMallocAsync" + else: + env_var += ",backend:cudaMallocAsync" + + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var + print("CUDA Malloc Async Enabled") diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..12a4df4be4aa406a5bb08bbd0ae5495363490da7 --- /dev/null +++ b/toolkit/custom_adapter.py @@ -0,0 +1,1026 @@ +import math +import torch +import sys + +from PIL import Image +from torch.nn import Parameter +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, CLIPTextModel, \ + CLIPTokenizer, T5Tokenizer + +from toolkit.models.clip_fusion import CLIPFusionModule +from toolkit.models.clip_pre_processor import CLIPImagePreProcessor +from toolkit.models.ilora import InstantLoRAModule +from toolkit.models.single_value_adapter import SingleValueAdapter +from toolkit.models.te_adapter import TEAdapter +from toolkit.models.te_aug_adapter import TEAugAdapter +from toolkit.models.vd_adapter import VisionDirectAdapter +from toolkit.models.redux import ReduxImageEncoder +from toolkit.paths import REPOS_ROOT +from 
toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder +from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model +from toolkit.train_tools import get_torch_dtype +from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible +import random + +sys.path.append(REPOS_ROOT) +from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict +from collections import OrderedDict +from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \ + AttnProcessor2_0 +from ipadapter.ip_adapter.ip_adapter import ImageProjModel +from ipadapter.ip_adapter.resampler import Resampler +from toolkit.config_modules import AdapterConfig, AdapterTypes +from toolkit.prompt_utils import PromptEmbeds +import weakref + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPVisionModel, + AutoImageProcessor, + ConvNextModel, + ConvNextForImageClassification, + ConvNextImageProcessor, + UMT5EncoderModel, LlamaTokenizerFast +) +from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel + +from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification + +from transformers import ViTFeatureExtractor, ViTForImageClassification + +import torch.nn.functional as F + + +class CustomAdapter(torch.nn.Module): + def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'): + super().__init__() + self.config = adapter_config + self.sd_ref: weakref.ref = weakref.ref(sd) + self.device = self.sd_ref().unet.device + self.image_processor: CLIPImageProcessor = None + self.input_size = 224 + self.adapter_type: AdapterTypes = self.config.type + self.current_scale = 1.0 + self.is_active = True + self.flag_word = "fla9wor0" + self.is_unconditional_run = False + + self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None + + self.fuse_module: FuseModule = None + + self.lora: None = None + + self.position_ids: Optional[List[int]] = None + + self.num_control_images = 1 + self.token_mask: Optional[torch.Tensor] = None + + # setup clip + self.setup_clip() + # add for dataloader + self.clip_image_processor = self.image_processor + + self.clip_fusion_module: CLIPFusionModule = None + self.ilora_module: InstantLoRAModule = None + + self.te: Union[T5EncoderModel, CLIPTextModel] = None + self.tokenizer: CLIPTokenizer = None + self.te_adapter: TEAdapter = None + self.te_augmenter: TEAugAdapter = None + self.vd_adapter: VisionDirectAdapter = None + self.single_value_adapter: SingleValueAdapter = None + self.redux_adapter: ReduxImageEncoder = None + + self.conditional_embeds: Optional[torch.Tensor] = None + self.unconditional_embeds: Optional[torch.Tensor] = None + + self.setup_adapter() + + if self.adapter_type == 'photo_maker': + # try to load from our name_or_path + if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'): + self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False) + # add the trigger word to the tokenizer + if isinstance(self.sd_ref().tokenizer, list): + for tokenizer in self.sd_ref().tokenizer: + tokenizer.add_tokens([self.flag_word], special_tokens=True) + else: + self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True) + elif self.config.name_or_path is not None: + 
loaded_state_dict = load_custom_adapter_model( + self.config.name_or_path, + self.sd_ref().device, + dtype=self.sd_ref().dtype, + ) + self.load_state_dict(loaded_state_dict, strict=False) + + def setup_adapter(self): + torch_dtype = get_torch_dtype(self.sd_ref().dtype) + if self.adapter_type == 'photo_maker': + sd = self.sd_ref() + embed_dim = sd.unet.config['cross_attention_dim'] + self.fuse_module = FuseModule(embed_dim) + elif self.adapter_type == 'clip_fusion': + sd = self.sd_ref() + embed_dim = sd.unet.config['cross_attention_dim'] + + vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2) + if self.config.image_encoder_arch == 'clip': + vision_tokens = vision_tokens + 1 + self.clip_fusion_module = CLIPFusionModule( + text_hidden_size=embed_dim, + text_tokens=77, + vision_hidden_size=self.vision_encoder.config.hidden_size, + vision_tokens=vision_tokens + ) + elif self.adapter_type == 'ilora': + vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2) + if self.config.image_encoder_arch == 'clip': + vision_tokens = vision_tokens + 1 + + vision_hidden_size = self.vision_encoder.config.hidden_size + + if self.config.clip_layer == 'image_embeds': + vision_tokens = 1 + vision_hidden_size = self.vision_encoder.config.projection_dim + + self.ilora_module = InstantLoRAModule( + vision_tokens=vision_tokens, + vision_hidden_size=vision_hidden_size, + head_dim=self.config.head_dim, + num_heads=self.config.num_heads, + sd=self.sd_ref(), + config=self.config + ) + elif self.adapter_type == 'text_encoder': + if self.config.text_encoder_arch == 't5': + te_kwargs = {} + # te_kwargs['load_in_4bit'] = True + # te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + self.te = T5EncoderModel.from_pretrained( + self.config.text_encoder_path, + torch_dtype=torch_dtype, + **te_kwargs + ) + + # self.te.to = lambda *args, **kwargs: None + self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path) + elif self.config.text_encoder_arch == 'pile-t5': + te_kwargs = {} + # te_kwargs['load_in_4bit'] = True + # te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + self.te = UMT5EncoderModel.from_pretrained( + self.config.text_encoder_path, + torch_dtype=torch_dtype, + **te_kwargs + ) + + # self.te.to = lambda *args, **kwargs: None + self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path) + if self.tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + elif self.config.text_encoder_arch == 'clip': + self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device, + dtype=torch_dtype) + self.tokenizer = CLIPTokenizer.from_pretrained(self.config.text_encoder_path) + else: + raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}") + + self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer) + elif self.adapter_type == 'te_augmenter': + self.te_augmenter = TEAugAdapter(self, self.sd_ref()) + elif self.adapter_type == 'vision_direct': + self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder) + elif self.adapter_type == 'single_value': + self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens) + elif self.adapter_type == 'redux': + vision_hidden_size = self.vision_encoder.config.hidden_size + self.redux_adapter = 
ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype) + else: + raise ValueError(f"unknown adapter type: {self.adapter_type}") + + def forward(self, *args, **kwargs): + # dont think this is used + # if self.adapter_type == 'photo_maker': + # id_pixel_values = args[0] + # prompt_embeds: PromptEmbeds = args[1] + # class_tokens_mask = args[2] + # + # grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled() + # + # with torch.set_grad_enabled(grads_on_image_encoder): + # id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False) + # + # if not grads_on_image_encoder: + # id_embeds = id_embeds.detach() + # + # prompt_embeds = prompt_embeds.detach() + # + # updated_prompt_embeds = self.fuse_module( + # prompt_embeds, id_embeds, class_tokens_mask + # ) + # + # return updated_prompt_embeds + # else: + raise NotImplementedError + + def setup_clip(self): + adapter_config = self.config + sd = self.sd_ref() + if self.config.type == "text_encoder" or self.config.type == "single_value": + return + if self.config.type == 'photo_maker': + try: + self.image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path) + except EnvironmentError: + self.image_processor = CLIPImageProcessor() + if self.config.image_encoder_path is None: + self.vision_encoder = PhotoMakerCLIPEncoder() + else: + self.vision_encoder = PhotoMakerCLIPEncoder.from_pretrained(self.config.image_encoder_path) + elif self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+': + try: + self.image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = CLIPImageProcessor() + self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'siglip': + from transformers import SiglipImageProcessor, SiglipVisionModel + try: + self.image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = SiglipImageProcessor() + self.vision_encoder = SiglipVisionModel.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'pixtral': + self.image_processor = PixtralVisionImagePreprocessorCompatible( + max_image_size=self.config.pixtral_max_image_size, + ) + self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained( + adapter_config.image_encoder_path, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'vit': + try: + self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = ViTFeatureExtractor() + self.vision_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'safe': + try: + self.image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = SAFEImageProcessor() + self.vision_encoder = SAFEVisionModel( + in_channels=3, + num_tokens=self.config.safe_tokens, + num_vectors=sd.unet.config['cross_attention_dim'], + 
reducer_channels=self.config.safe_reducer_channels, + channels=self.config.safe_channels, + downscale_factor=8 + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'convnext': + try: + self.image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.image_processor = ConvNextImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.vision_encoder = ConvNextForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'vit-hybrid': + try: + self.image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.image_processor = ViTHybridImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.vision_encoder = ViTHybridForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + else: + raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}") + + self.input_size = self.vision_encoder.config.image_size + + if self.config.quad_image: # 4x4 image + # self.clip_image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = self.vision_encoder.config.image_size * 2 + + # update the preprocessor so images come in at the right size + if 'height' in self.image_processor.size: + self.image_processor.size['height'] = preprocessor_input_size + self.image_processor.size['width'] = preprocessor_input_size + elif hasattr(self.image_processor, 'crop_size'): + self.image_processor.size['shortest_edge'] = preprocessor_input_size + self.image_processor.crop_size['height'] = preprocessor_input_size + self.image_processor.crop_size['width'] = preprocessor_input_size + + if self.config.image_encoder_arch == 'clip+': + # self.image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = self.vision_encoder.config.image_size * 4 + + # update the preprocessor so images come in at the right size + self.image_processor.size['shortest_edge'] = preprocessor_input_size + self.image_processor.crop_size['height'] = preprocessor_input_size + self.image_processor.crop_size['width'] = preprocessor_input_size + + self.preprocessor = CLIPImagePreProcessor( + input_size=preprocessor_input_size, + clip_input_size=self.vision_encoder.config.image_size, + ) + if 'height' in self.image_processor.size: + self.input_size = self.image_processor.size['height'] + else: + self.input_size = self.image_processor.crop_size['height'] + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + strict = False + if self.config.train_only_image_encoder and 'vd_adapter' not in state_dict and 'dvadapter' not in state_dict: + # we are loading pure clip weights. 
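+ # note: `strict` was forced to False at the top of this method, so extra or missing keys
+ # in a pure image-encoder checkpoint are tolerated by the load below; execution then falls
+ # through to the named-key checks further down, which simply match nothing for such a
+ # checkpoint.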
+ self.vision_encoder.load_state_dict(state_dict, strict=strict) + + if 'lora_weights' in state_dict: + # todo add LoRA + # self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker") + # self.sd_ref().pipeline.fuse_lora() + pass + if 'clip_fusion' in state_dict: + self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict) + if 'id_encoder' in state_dict and (self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion'): + self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict) + # check to see if the fuse weights are there + fuse_weights = {} + for k, v in state_dict['id_encoder'].items(): + if k.startswith('fuse_module'): + k = k.replace('fuse_module.', '') + fuse_weights[k] = v + if len(fuse_weights) > 0: + try: + self.fuse_module.load_state_dict(fuse_weights, strict=strict) + except Exception as e: + + print(e) + # force load it + print(f"force loading fuse module as it did not match") + current_state_dict = self.fuse_module.state_dict() + for k, v in fuse_weights.items(): + if len(v.shape) == 1: + current_state_dict[k] = v[:current_state_dict[k].shape[0]] + elif len(v.shape) == 2: + current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1]] + elif len(v.shape) == 3: + current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1], + :current_state_dict[k].shape[2]] + elif len(v.shape) == 4: + current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1], + :current_state_dict[k].shape[2], :current_state_dict[k].shape[3]] + else: + raise ValueError(f"unknown shape: {v.shape}") + self.fuse_module.load_state_dict(current_state_dict, strict=strict) + + if 'te_adapter' in state_dict: + self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict) + + if 'te_augmenter' in state_dict: + self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict) + + if 'vd_adapter' in state_dict: + self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict) + if 'dvadapter' in state_dict: + self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=False) + + if 'sv_adapter' in state_dict: + self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict) + + if 'vision_encoder' in state_dict: + self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict) + + if 'fuse_module' in state_dict: + self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict) + + if 'ilora' in state_dict: + try: + self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict) + except Exception as e: + print(e) + if 'redux_up' in state_dict: + # state dict is seperated. so recombine it + new_dict = {} + for k, v in state_dict.items(): + for k2, v2 in v.items(): + new_dict[k + '.' 
+ k2] = v2 + self.redux_adapter.load_state_dict(new_dict, strict=True) + + pass + + def state_dict(self) -> OrderedDict: + state_dict = OrderedDict() + if self.config.train_only_image_encoder: + return self.vision_encoder.state_dict() + + if self.adapter_type == 'photo_maker': + if self.config.train_image_encoder: + state_dict["id_encoder"] = self.vision_encoder.state_dict() + + state_dict["fuse_module"] = self.fuse_module.state_dict() + + # todo save LoRA + return state_dict + + elif self.adapter_type == 'clip_fusion': + if self.config.train_image_encoder: + state_dict["vision_encoder"] = self.vision_encoder.state_dict() + state_dict["clip_fusion"] = self.clip_fusion_module.state_dict() + return state_dict + elif self.adapter_type == 'text_encoder': + state_dict["te_adapter"] = self.te_adapter.state_dict() + return state_dict + elif self.adapter_type == 'te_augmenter': + if self.config.train_image_encoder: + state_dict["vision_encoder"] = self.vision_encoder.state_dict() + state_dict["te_augmenter"] = self.te_augmenter.state_dict() + return state_dict + elif self.adapter_type == 'vision_direct': + state_dict["dvadapter"] = self.vd_adapter.state_dict() + # if self.config.train_image_encoder: # always return vision encoder + state_dict["vision_encoder"] = self.vision_encoder.state_dict() + return state_dict + elif self.adapter_type == 'single_value': + state_dict["sv_adapter"] = self.single_value_adapter.state_dict() + return state_dict + elif self.adapter_type == 'ilora': + if self.config.train_image_encoder: + state_dict["vision_encoder"] = self.vision_encoder.state_dict() + state_dict["ilora"] = self.ilora_module.state_dict() + return state_dict + elif self.adapter_type == 'redux': + d = self.redux_adapter.state_dict() + for k, v in d.items(): + state_dict[k] = v + return state_dict + else: + raise NotImplementedError + + def add_extra_values(self, extra_values: torch.Tensor, is_unconditional=False): + if self.adapter_type == 'single_value': + if is_unconditional: + self.unconditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype)) + else: + self.conditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype)) + + + def condition_prompt( + self, + prompt: Union[List[str], str], + is_unconditional: bool = False, + ): + if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'redux': + return prompt + elif self.adapter_type == 'text_encoder': + # todo allow for training + with torch.no_grad(): + # encode and save the embeds + if is_unconditional: + self.unconditional_embeds = self.te_adapter.encode_text(prompt).detach() + else: + self.conditional_embeds = self.te_adapter.encode_text(prompt).detach() + return prompt + elif self.adapter_type == 'photo_maker': + if is_unconditional: + return prompt + else: + + with torch.no_grad(): + was_list = isinstance(prompt, list) + if not was_list: + prompt_list = [prompt] + else: + prompt_list = prompt + + new_prompt_list = [] + token_mask_list = [] + + for prompt in prompt_list: + + our_class = None + # find a class in the prompt + prompt_parts = prompt.split(' ') + prompt_parts = [p.strip().lower() for p in prompt_parts if len(p) > 0] + + new_prompt_parts = [] + tokened_prompt_parts = [] + for idx, prompt_part in enumerate(prompt_parts): + new_prompt_parts.append(prompt_part) + tokened_prompt_parts.append(prompt_part) + if prompt_part in self.config.class_names: + our_class = prompt_part + # add the flag word + 
tokened_prompt_parts.append(self.flag_word) + + if self.num_control_images > 1: + # add the rest + for _ in range(self.num_control_images - 1): + new_prompt_parts.extend(prompt_parts[idx + 1:]) + + # add the rest + tokened_prompt_parts.extend(prompt_parts[idx + 1:]) + new_prompt_parts.extend(prompt_parts[idx + 1:]) + + break + + prompt = " ".join(new_prompt_parts) + tokened_prompt = " ".join(tokened_prompt_parts) + + if our_class is None: + # add the first one to the front of the prompt + tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt + our_class = self.config.class_names[0] + prompt = " ".join( + [self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt + + # add the prompt to the list + new_prompt_list.append(prompt) + + # tokenize them with just the first tokenizer + tokenizer = self.sd_ref().tokenizer + if isinstance(tokenizer, list): + tokenizer = tokenizer[0] + + flag_token = tokenizer.convert_tokens_to_ids(self.flag_word) + + tokenized_prompt = tokenizer.encode(prompt) + tokenized_tokened_prompt = tokenizer.encode(tokened_prompt) + + flag_idx = tokenized_tokened_prompt.index(flag_token) + + class_token = tokenized_prompt[flag_idx - 1] + + boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool) + boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool))) + boolean_mask = boolean_mask.to(self.device) + # zero pad it to 77 + boolean_mask = F.pad(boolean_mask, (0, 77 - boolean_mask.shape[0]), value=False) + + token_mask_list.append(boolean_mask) + + self.token_mask = torch.cat(token_mask_list, dim=0).to(self.device) + + prompt_list = new_prompt_list + + if not was_list: + prompt = prompt_list[0] + else: + prompt = prompt_list + + return prompt + + else: + return prompt + + def condition_encoded_embeds( + self, + tensors_0_1: torch.Tensor, + prompt_embeds: PromptEmbeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=False, + quad_count=4, + is_generating_samples=False, + ) -> PromptEmbeds: + if self.adapter_type == 'text_encoder' and is_generating_samples: + # replace the prompt embed with ours + if is_unconditional: + return self.unconditional_embeds.clone() + return self.conditional_embeds.clone() + + if self.adapter_type == 'ilora': + return prompt_embeds + + if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'redux': + if is_unconditional: + # we dont condition the negative embeds for photo maker + return prompt_embeds.clone() + with torch.no_grad(): + # on training the clip image is created in the dataloader + if not has_been_preprocessed: + # tensors should be 0-1 + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. 
Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + clip_image = self.image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + do_convert_rgb=True + ).pixel_values + else: + clip_image = tensors_0_1 + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + + if self.config.quad_image: + # split the 4x4 grid and stack on batch + ci1, ci2 = clip_image.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + clip_image = torch.cat(to_cat, dim=0).detach() + + if self.adapter_type == 'photo_maker': + # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image + clip_image = clip_image.unsqueeze(1) + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.vision_encoder.train() + clip_image = clip_image.requires_grad_(True) + id_embeds = self.vision_encoder( + clip_image, + do_projection2=isinstance(self.sd_ref().text_encoder, list), + ) + else: + with torch.no_grad(): + self.vision_encoder.eval() + id_embeds = self.vision_encoder( + clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list) + ).detach() + + prompt_embeds.text_embeds = self.fuse_module( + prompt_embeds.text_embeds, + id_embeds, + self.token_mask + ) + return prompt_embeds + elif self.adapter_type == 'clip_fusion': + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.vision_encoder.train() + clip_image = clip_image.requires_grad_(True) + id_embeds = self.vision_encoder( + clip_image, + output_hidden_states=True, + ) + else: + with torch.no_grad(): + self.vision_encoder.eval() + id_embeds = self.vision_encoder( + clip_image, output_hidden_states=True + ) + + img_embeds = id_embeds['last_hidden_state'] + + if self.config.quad_image: + # get the outputs of the quat + chunks = img_embeds.chunk(quad_count, dim=0) + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + img_embeds = chunk_sum / quad_count + + if not is_training or not self.config.train_image_encoder: + img_embeds = img_embeds.detach() + + prompt_embeds.text_embeds = self.clip_fusion_module( + prompt_embeds.text_embeds, + img_embeds + ) + return prompt_embeds + + elif self.adapter_type == 'redux': + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.vision_encoder.train() + clip_image = clip_image.requires_grad_(True) + id_embeds = self.vision_encoder( + clip_image, + output_hidden_states=True, + ) + else: + with torch.no_grad(): + self.vision_encoder.eval() + id_embeds = self.vision_encoder( + clip_image, output_hidden_states=True + ) + + img_embeds = id_embeds['last_hidden_state'] + + if self.config.quad_image: + # get the outputs of the quat + chunks = img_embeds.chunk(quad_count, dim=0) + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + img_embeds = chunk_sum / quad_count + + if not is_training or not self.config.train_image_encoder: + img_embeds = img_embeds.detach() + + img_embeds = self.redux_adapter(img_embeds.to(self.device, get_torch_dtype(self.sd_ref().dtype))) + + prompt_embeds.text_embeds = torch.cat((prompt_embeds.text_embeds, img_embeds), dim=-2) + return prompt_embeds + else: + 
return prompt_embeds + + def get_empty_clip_image(self, batch_size: int, shape=None) -> torch.Tensor: + with torch.no_grad(): + if shape is None: + shape = [batch_size, 3, self.input_size, self.input_size] + tensors_0_1 = torch.rand(shape, device=self.device) + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + return clip_image.detach() + + def train(self, mode: bool = True): + if self.config.train_image_encoder: + self.vision_encoder.train(mode) + super().train(mode) + + def trigger_pre_te( + self, + tensors_0_1: torch.Tensor, + is_training=False, + has_been_preprocessed=False, + quad_count=4, + batch_size=1, + ) -> PromptEmbeds: + if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter': + skip_unconditional = self.sd_ref().is_flux + if tensors_0_1 is None: + tensors_0_1 = self.get_empty_clip_image(batch_size) + has_been_preprocessed = True + + with torch.no_grad(): + # on training the clip image is created in the dataloader + if not has_been_preprocessed: + # tensors should be 0-1 + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + clip_image = self.image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + else: + clip_image = tensors_0_1 + + # if is pixtral + if self.config.image_encoder_arch == 'pixtral' and self.config.pixtral_random_image_size: + # get the random size + random_size = random.randint(256, self.config.pixtral_max_image_size) + # images are already sized for max size, we have to fit them to the pixtral patch size to reduce / enlarge it farther. 
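+ # the resize below treats sqrt(w * h) as the current base size, shrinks by the ratio to the
+ # randomly picked target (only when ratio > 1, so images are never upscaled), then rounds
+ # each side up to a whole number of patches via (dim - 1) // image_patch_size + 1 so the
+ # encoder always receives complete patches before the bicubic interpolation.
+ # (new_image_size is assembled as (width, height) while F.interpolate expects (height,
+ # width), so non-square inputs may come out with their dimensions swapped.)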
+ h, w = clip_image.shape[2], clip_image.shape[3] + current_base_size = int(math.sqrt(w * h)) + ratio = current_base_size / random_size + if ratio > 1: + w = round(w / ratio) + h = round(h / ratio) + + width_tokens = (w - 1) // self.image_processor.image_patch_size + 1 + height_tokens = (h - 1) // self.image_processor.image_patch_size + 1 + assert width_tokens > 0 + assert height_tokens > 0 + + new_image_size = ( + width_tokens * self.image_processor.image_patch_size, + height_tokens * self.image_processor.image_patch_size, + ) + + # resize the image + clip_image = F.interpolate(clip_image, size=new_image_size, mode='bicubic', align_corners=False) + + + batch_size = clip_image.shape[0] + if (self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter') and not skip_unconditional: + # add an unconditional so we can save it + unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to( + clip_image.device, dtype=clip_image.dtype + ) + clip_image = torch.cat([unconditional, clip_image], dim=0) + + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + + if self.config.quad_image: + # split the 4x4 grid and stack on batch + ci1, ci2 = clip_image.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + clip_image = torch.cat(to_cat, dim=0).detach() + + if self.adapter_type == 'ilora': + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.vision_encoder.train() + clip_image = clip_image.requires_grad_(True) + id_embeds = self.vision_encoder( + clip_image, + output_hidden_states=True, + ) + else: + with torch.no_grad(): + self.vision_encoder.eval() + id_embeds = self.vision_encoder( + clip_image, output_hidden_states=True + ) + + if self.config.clip_layer == 'penultimate_hidden_states': + img_embeds = id_embeds.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + img_embeds = id_embeds.hidden_states[-1] + elif self.config.clip_layer == 'image_embeds': + img_embeds = id_embeds.image_embeds + else: + raise ValueError(f"unknown clip layer: {self.config.clip_layer}") + + if self.config.quad_image: + # get the outputs of the quat + chunks = img_embeds.chunk(quad_count, dim=0) + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + img_embeds = chunk_sum / quad_count + + if not is_training or not self.config.train_image_encoder: + img_embeds = img_embeds.detach() + + self.ilora_module(img_embeds) + if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter': + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.vision_encoder.train() + clip_image = clip_image.requires_grad_(True) + else: + with torch.no_grad(): + self.vision_encoder.eval() + clip_output = self.vision_encoder( + clip_image, + output_hidden_states=True, + ) + if self.config.clip_layer == 'penultimate_hidden_states': + # they skip last layer for ip+ + # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26 + clip_image_embeds = clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + clip_image_embeds = clip_output.hidden_states[-1] + else: + if hasattr(clip_output, 'image_embeds'): + clip_image_embeds = 
clip_output.image_embeds + elif hasattr(clip_output, 'pooler_output'): + clip_image_embeds = clip_output.pooler_output + # TODO should we always norm image embeds? + # get norm embeddings + # l2_norm = torch.norm(clip_image_embeds, p=2) + # clip_image_embeds = clip_image_embeds / l2_norm + + if not is_training or not self.config.train_image_encoder: + clip_image_embeds = clip_image_embeds.detach() + + if self.adapter_type == 'te_augmenter': + clip_image_embeds = self.te_augmenter(clip_image_embeds) + + if self.adapter_type == 'vision_direct': + clip_image_embeds = self.vd_adapter(clip_image_embeds) + + # save them to the conditional and unconditional + try: + if skip_unconditional: + self.unconditional_embeds, self.conditional_embeds = None, clip_image_embeds + else: + self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0) + except ValueError: + raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}") + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + if self.config.train_only_image_encoder: + yield from self.vision_encoder.parameters(recurse) + return + if self.config.type == 'photo_maker': + yield from self.fuse_module.parameters(recurse) + if self.config.train_image_encoder: + yield from self.vision_encoder.parameters(recurse) + elif self.config.type == 'clip_fusion': + yield from self.clip_fusion_module.parameters(recurse) + if self.config.train_image_encoder: + yield from self.vision_encoder.parameters(recurse) + elif self.config.type == 'ilora': + yield from self.ilora_module.parameters(recurse) + if self.config.train_image_encoder: + yield from self.vision_encoder.parameters(recurse) + elif self.config.type == 'text_encoder': + for attn_processor in self.te_adapter.adapter_modules: + yield from attn_processor.parameters(recurse) + elif self.config.type == 'vision_direct': + if self.config.train_scaler: + # only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules) + yield self.vd_adapter.block_scaler + else: + for attn_processor in self.vd_adapter.adapter_modules: + yield from attn_processor.parameters(recurse) + if self.config.train_image_encoder: + yield from self.vision_encoder.parameters(recurse) + if self.vd_adapter.resampler is not None: + yield from self.vd_adapter.resampler.parameters(recurse) + if self.vd_adapter.pool is not None: + yield from self.vd_adapter.pool.parameters(recurse) + if self.vd_adapter.sparse_autoencoder is not None: + yield from self.vd_adapter.sparse_autoencoder.parameters(recurse) + elif self.config.type == 'te_augmenter': + yield from self.te_augmenter.parameters(recurse) + if self.config.train_image_encoder: + yield from self.vision_encoder.parameters(recurse) + elif self.config.type == 'single_value': + yield from self.single_value_adapter.parameters(recurse) + elif self.config.type == 'redux': + yield from self.redux_adapter.parameters(recurse) + else: + raise NotImplementedError + + def enable_gradient_checkpointing(self): + if hasattr(self.vision_encoder, "enable_gradient_checkpointing"): + self.vision_encoder.enable_gradient_checkpointing() + elif hasattr(self.vision_encoder, 'gradient_checkpointing'): + self.vision_encoder.gradient_checkpointing = True + + def get_additional_save_metadata(self) -> Dict[str, Any]: + additional = {} + if self.config.type == 'ilora': + extra = self.ilora_module.get_additional_save_metadata() + for k, v in extra.items(): + additional[k] = v + additional['clip_layer'] = 
self.config.clip_layer + additional['image_encoder_arch'] = self.config.head_dim + return additional + + def post_weight_update(self): + # do any kind of updates after the weight update + if self.config.type == 'vision_direct': + self.vd_adapter.post_weight_update() + pass \ No newline at end of file diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..5285b371734e874c831116fb6799fa99f1df5d48 --- /dev/null +++ b/toolkit/data_loader.py @@ -0,0 +1,677 @@ +import copy +import json +import os +import random +import traceback +from functools import lru_cache +from typing import List, TYPE_CHECKING + +import cv2 +import numpy as np +import torch +from PIL import Image +from PIL.ImageOps import exif_transpose +from torchvision import transforms +from torch.utils.data import Dataset, DataLoader, ConcatDataset +from tqdm import tqdm +import albumentations as A + +from toolkit.buckets import get_bucket_for_image_size, BucketResolution +from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config +from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO + +import platform + +def is_native_windows(): + return platform.system() == "Windows" and platform.release() != "2" + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + + +class RescaleTransform: + """Transform to rescale images to the range [-1, 1].""" + + def __call__(self, image): + return image * 2 - 1 + + +class NormalizeSDXLTransform: + """ + Transforms the range from 0 to 1 to SDXL mean and std per channel based on avgs over thousands of images + + Mean: tensor([ 0.0002, -0.1034, -0.1879]) + Standard Deviation: tensor([0.5436, 0.5116, 0.5033]) + """ + + def __call__(self, image): + return transforms.Normalize( + mean=[0.0002, -0.1034, -0.1879], + std=[0.5436, 0.5116, 0.5033], + )(image) + + +class NormalizeSD15Transform: + """ + Transforms the range from 0 to 1 to SDXL mean and std per channel based on avgs over thousands of images + + Mean: tensor([-0.1600, -0.2450, -0.3227]) + Standard Deviation: tensor([0.5319, 0.4997, 0.5139]) + + """ + + def __call__(self, image): + return transforms.Normalize( + mean=[-0.1600, -0.2450, -0.3227], + std=[0.5319, 0.4997, 0.5139], + )(image) + + + +class ImageDataset(Dataset, CaptionMixin): + def __init__(self, config): + self.config = config + self.name = self.get_config('name', 'dataset') + self.path = self.get_config('path', required=True) + self.scale = self.get_config('scale', 1) + self.random_scale = self.get_config('random_scale', False) + self.include_prompt = self.get_config('include_prompt', False) + self.default_prompt = self.get_config('default_prompt', '') + if self.include_prompt: + self.caption_type = self.get_config('caption_ext', 'txt') + else: + self.caption_type = None + # we always random crop if random scale is enabled + self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False) + + self.resolution = self.get_config('resolution', 256) + self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if + file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))] + + # this might take a while + print(f" - Preprocessing image dimensions") + new_file_list = [] + bad_count = 0 + for file in tqdm(self.file_list): + img = Image.open(file) + if int(min(img.size) * 
self.scale) >= self.resolution: + new_file_list.append(file) + else: + bad_count += 1 + + self.file_list = new_file_list + + print(f" - Found {len(self.file_list)} images") + print(f" - Found {bad_count} images that are too small") + assert len(self.file_list) > 0, f"no images found in {self.path}" + + self.transform = transforms.Compose([ + transforms.ToTensor(), + RescaleTransform(), + ]) + + def get_config(self, key, default=None, required=False): + if key in self.config: + value = self.config[key] + return value + elif required: + raise ValueError(f'config file error. Missing "config.dataset.{key}" key') + else: + return default + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, index): + img_path = self.file_list[index] + try: + img = exif_transpose(Image.open(img_path)).convert('RGB') + except Exception as e: + print(f"Error opening image: {img_path}") + print(e) + # make a noise image if we can't open it + img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)) + + # Downscale the source image first + img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC) + min_img_size = min(img.size) + + if self.random_crop: + if self.random_scale and min_img_size > self.resolution: + if min_img_size < self.resolution: + print( + f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}") + scale_size = self.resolution + else: + scale_size = random.randint(self.resolution, int(min_img_size)) + scaler = scale_size / min_img_size + scale_width = int((img.width + 5) * scaler) + scale_height = int((img.height + 5) * scaler) + img = img.resize((scale_width, scale_height), Image.BICUBIC) + img = transforms.RandomCrop(self.resolution)(img) + else: + img = transforms.CenterCrop(min_img_size)(img) + img = img.resize((self.resolution, self.resolution), Image.BICUBIC) + + img = self.transform(img) + + if self.include_prompt: + prompt = self.get_caption_item(index) + return img, prompt + else: + return img + + + + + +class AugmentedImageDataset(ImageDataset): + def __init__(self, config): + super().__init__(config) + self.augmentations = self.get_config('augmentations', []) + self.augmentations = [Augments(**aug) for aug in self.augmentations] + + augmentation_list = [] + for aug in self.augmentations: + # make sure method name is valid + assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}" + # get the method + method = getattr(A, aug.method_name) + # add the method to the list + augmentation_list.append(method(**aug.params)) + + self.aug_transform = A.Compose(augmentation_list) + self.original_transform = self.transform + # replace transform so we get raw pil image + self.transform = transforms.Compose([]) + + def __getitem__(self, index): + # get the original image + # image is a PIL image, convert to bgr + pil_image = super().__getitem__(index) + open_cv_image = np.array(pil_image) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + # apply augmentations + augmented = self.aug_transform(image=open_cv_image)["image"] + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + # return both # return image as 0 - 1 tensor + return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented) + + +class PairedImageDataset(Dataset): + def __init__(self, config): + super().__init__() + self.config = config + 
self.size = self.get_config('size', 512) + self.path = self.get_config('path', None) + self.pos_folder = self.get_config('pos_folder', None) + self.neg_folder = self.get_config('neg_folder', None) + + self.default_prompt = self.get_config('default_prompt', '') + self.network_weight = self.get_config('network_weight', 1.0) + self.pos_weight = self.get_config('pos_weight', self.network_weight) + self.neg_weight = self.get_config('neg_weight', self.network_weight) + + supported_exts = ('.jpg', '.jpeg', '.png', '.webp', '.JPEG', '.JPG', '.PNG', '.WEBP') + + if self.pos_folder is not None and self.neg_folder is not None: + # find matching files + self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if + file.lower().endswith(supported_exts)] + self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if + file.lower().endswith(supported_exts)] + + matched_files = [] + for pos_file in self.pos_file_list: + pos_file_no_ext = os.path.splitext(pos_file)[0] + for neg_file in self.neg_file_list: + neg_file_no_ext = os.path.splitext(neg_file)[0] + if os.path.basename(pos_file_no_ext) == os.path.basename(neg_file_no_ext): + matched_files.append((neg_file, pos_file)) + break + + # remove duplicates + matched_files = [t for t in (set(tuple(i) for i in matched_files))] + + self.file_list = matched_files + print(f" - Found {len(self.file_list)} matching pairs") + else: + self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if + file.lower().endswith(supported_exts)] + print(f" - Found {len(self.file_list)} images") + + self.transform = transforms.Compose([ + transforms.ToTensor(), + RescaleTransform(), + ]) + + def get_all_prompts(self): + prompts = [] + for index in range(len(self.file_list)): + prompts.append(self.get_prompt_item(index)) + + # remove duplicates + prompts = list(set(prompts)) + return prompts + + def __len__(self): + return len(self.file_list) + + def get_config(self, key, default=None, required=False): + if key in self.config: + value = self.config[key] + return value + elif required: + raise ValueError(f'config file error. 
Missing "config.dataset.{key}" key') + else: + return default + + def get_prompt_item(self, index): + img_path_or_tuple = self.file_list[index] + if isinstance(img_path_or_tuple, tuple): + # check if either has a prompt file + path_no_ext = os.path.splitext(img_path_or_tuple[0])[0] + prompt_path = path_no_ext + '.txt' + if not os.path.exists(prompt_path): + path_no_ext = os.path.splitext(img_path_or_tuple[1])[0] + prompt_path = path_no_ext + '.txt' + else: + img_path = img_path_or_tuple + # see if prompt file exists + path_no_ext = os.path.splitext(img_path)[0] + prompt_path = path_no_ext + '.txt' + + if os.path.exists(prompt_path): + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + # remove any newlines + prompt = prompt.replace('\n', ', ') + # remove new lines for all operating systems + prompt = prompt.replace('\r', ', ') + prompt_split = prompt.split(',') + # remove empty strings + prompt_split = [p.strip() for p in prompt_split if p.strip()] + # join back together + prompt = ', '.join(prompt_split) + else: + prompt = self.default_prompt + return prompt + + def __getitem__(self, index): + img_path_or_tuple = self.file_list[index] + if isinstance(img_path_or_tuple, tuple): + # load both images + img_path = img_path_or_tuple[0] + img1 = exif_transpose(Image.open(img_path)).convert('RGB') + img_path = img_path_or_tuple[1] + img2 = exif_transpose(Image.open(img_path)).convert('RGB') + + # always use # 2 (pos) + bucket_resolution = get_bucket_for_image_size( + width=img2.width, + height=img2.height, + resolution=self.size, + # divisibility=self. + ) + + # images will be same base dimension, but may be trimmed. We need to shrink and then central crop + if bucket_resolution['width'] > bucket_resolution['height']: + img1_scale_to_height = bucket_resolution["height"] + img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height)) + img2_scale_to_height = bucket_resolution["height"] + img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height)) + else: + img1_scale_to_width = bucket_resolution["width"] + img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width)) + img2_scale_to_width = bucket_resolution["width"] + img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width)) + + img1_crop_height = bucket_resolution["height"] + img1_crop_width = bucket_resolution["width"] + img2_crop_height = bucket_resolution["height"] + img2_crop_width = bucket_resolution["width"] + + # scale then center crop images + img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC) + img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1) + img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC) + img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2) + + # combine them side by side + img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height))) + img.paste(img1, (0, 0)) + img.paste(img2, (img1.width, 0)) + else: + img_path = img_path_or_tuple + img = exif_transpose(Image.open(img_path)).convert('RGB') + height = self.size + # determine width to keep aspect ratio + width = int(img.size[0] * height / img.size[1]) + + # Downscale the source image first + img = img.resize((width, height), Image.BICUBIC) + + prompt = self.get_prompt_item(index) + img = self.transform(img) + + return img, prompt, (self.neg_weight, self.pos_weight) + + +class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, 
CaptionMixin, Dataset): + + def __init__( + self, + dataset_config: 'DatasetConfig', + batch_size=1, + sd: 'StableDiffusion' = None, + ): + super().__init__() + self.dataset_config = dataset_config + folder_path = dataset_config.folder_path + self.dataset_path = dataset_config.dataset_path + if self.dataset_path is None: + self.dataset_path = folder_path + + self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk + self.is_caching_latents_to_memory = dataset_config.cache_latents + self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk + self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk + self.epoch_num = 0 + + self.sd = sd + + if self.sd is None and self.is_caching_latents: + raise ValueError(f"sd is required for caching latents") + + self.caption_type = dataset_config.caption_ext + self.default_caption = dataset_config.default_caption + self.random_scale = dataset_config.random_scale + self.scale = dataset_config.scale + self.batch_size = batch_size + # we always random crop if random scale is enabled + self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop + self.resolution = dataset_config.resolution + self.caption_dict = None + self.file_list: List['FileItemDTO'] = [] + + # check if dataset_path is a folder or json + if os.path.isdir(self.dataset_path): + file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))] + else: + # assume json + with open(self.dataset_path, 'r') as f: + self.caption_dict = json.load(f) + # keys are file paths + file_list = list(self.caption_dict.keys()) + + if self.dataset_config.num_repeats > 1: + # repeat the list + file_list = file_list * self.dataset_config.num_repeats + + if self.dataset_config.standardize_images: + if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd: + NormalizeMethod = NormalizeSDXLTransform + else: + NormalizeMethod = NormalizeSD15Transform + + self.transform = transforms.Compose([ + transforms.ToTensor(), + RescaleTransform(), + NormalizeMethod(), + ]) + else: + self.transform = transforms.Compose([ + transforms.ToTensor(), + RescaleTransform(), + ]) + + # this might take a while + print(f"Dataset: {self.dataset_path}") + print(f" - Preprocessing image dimensions") + dataset_folder = self.dataset_path + if not os.path.isdir(self.dataset_path): + dataset_folder = os.path.dirname(dataset_folder) + dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json') + dataloader_version = "0.1.1" + if os.path.exists(dataset_size_file): + try: + with open(dataset_size_file, 'r') as f: + self.size_database = json.load(f) + + if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version: + print("Upgrading size database to new version") + # old version, delete and recreate + self.size_database = {} + except Exception as e: + print(f"Error loading size database: {dataset_size_file}") + print(e) + self.size_database = {} + else: + self.size_database = {} + + self.size_database["__version__"] = dataloader_version + + bad_count = 0 + for file in tqdm(file_list): + try: + file_item = FileItemDTO( + sd=self.sd, + path=file, + dataset_config=dataset_config, + dataloader_transforms=self.transform, + size_database=self.size_database, + dataset_root=dataset_folder, + ) + self.file_list.append(file_item) + except Exception as e: + print(traceback.format_exc()) + print(f"Error processing 
image: {file}") + print(e) + bad_count += 1 + + # save the size database + with open(dataset_size_file, 'w') as f: + json.dump(self.size_database, f) + + print(f" - Found {len(self.file_list)} images") + # print(f" - Found {bad_count} images that are too small") + assert len(self.file_list) > 0, f"no images found in {self.dataset_path}" + + # handle x axis flips + if self.dataset_config.flip_x: + print(" - adding x axis flips") + current_file_list = [x for x in self.file_list] + for file_item in current_file_list: + # create a copy that is flipped on the x axis + new_file_item = copy.deepcopy(file_item) + new_file_item.flip_x = True + self.file_list.append(new_file_item) + + # handle y axis flips + if self.dataset_config.flip_y: + print(" - adding y axis flips") + current_file_list = [x for x in self.file_list] + for file_item in current_file_list: + # create a copy that is flipped on the y axis + new_file_item = copy.deepcopy(file_item) + new_file_item.flip_y = True + self.file_list.append(new_file_item) + + if self.dataset_config.flip_x or self.dataset_config.flip_y: + print(f" - Found {len(self.file_list)} images after adding flips") + + + self.setup_epoch() + + def setup_epoch(self): + if self.epoch_num == 0: + # initial setup + # do not call for now + if self.dataset_config.buckets: + # setup buckets + self.setup_buckets() + if self.is_caching_latents: + self.cache_latents_all_latents() + if self.is_caching_clip_vision_to_disk: + self.cache_clip_vision_to_disk() + else: + if self.dataset_config.poi is not None: + # handle cropping to a specific point of interest + # setup buckets every epoch + self.setup_buckets(quiet=True) + self.epoch_num += 1 + + def __len__(self): + if self.dataset_config.buckets: + return len(self.batch_indices) + return len(self.file_list) + + def _get_single_item(self, index) -> 'FileItemDTO': + file_item = copy.deepcopy(self.file_list[index]) + file_item.load_and_process_image(self.transform) + file_item.load_caption(self.caption_dict) + return file_item + + def __getitem__(self, item): + if self.dataset_config.buckets: + # for buckets we collate ourselves for now + # todo allow a scheduler to dynamically make buckets + # we collate ourselves + if len(self.batch_indices) - 1 < item: + # tried everything to solve this. No way to reset length when redoing things. 
Pick another index + item = random.randint(0, len(self.batch_indices) - 1) + idx_list = self.batch_indices[item] + return [self._get_single_item(idx) for idx in idx_list] + else: + # Dataloader is batching + return self._get_single_item(item) + + +def get_dataloader_from_datasets( + dataset_options, + batch_size=1, + sd: 'StableDiffusion' = None, +) -> DataLoader: + if dataset_options is None or len(dataset_options) == 0: + return None + + datasets = [] + has_buckets = False + is_caching_latents = False + + dataset_config_list = [] + # preprocess them all + for dataset_option in dataset_options: + if isinstance(dataset_option, DatasetConfig): + dataset_config_list.append(dataset_option) + else: + # preprocess raw data + split_configs = preprocess_dataset_raw_config([dataset_option]) + for x in split_configs: + dataset_config_list.append(DatasetConfig(**x)) + + for config in dataset_config_list: + + if config.type == 'image': + dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd) + datasets.append(dataset) + if config.buckets: + has_buckets = True + if config.cache_latents or config.cache_latents_to_disk: + is_caching_latents = True + else: + raise ValueError(f"invalid dataset type: {config.type}") + + concatenated_dataset = ConcatDataset(datasets) + + # todo build scheduler that can get buckets from all datasets that match + # todo and evenly distribute reg images + + def dto_collation(batch: List['FileItemDTO']): + # create DTO batch + batch = DataLoaderBatchDTO( + file_items=batch + ) + return batch + + # check if is caching latents + + dataloader_kwargs = {} + + if is_native_windows(): + dataloader_kwargs['num_workers'] = 0 + else: + dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers + dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor + + if has_buckets: + # make sure they all have buckets + for dataset in datasets: + assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none" + + data_loader = DataLoader( + concatenated_dataset, + batch_size=None, # we batch in the datasets for now + drop_last=False, + shuffle=True, + collate_fn=dto_collation, # Use the custom collate function + **dataloader_kwargs + ) + else: + data_loader = DataLoader( + concatenated_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=dto_collation, + **dataloader_kwargs + ) + return data_loader + + +def trigger_dataloader_setup_epoch(dataloader: DataLoader): + # hacky but needed because of different types of datasets and dataloaders + dataloader.len = None + if isinstance(dataloader.dataset, list): + for dataset in dataloader.dataset: + if hasattr(dataset, 'datasets'): + for sub_dataset in dataset.datasets: + if hasattr(sub_dataset, 'setup_epoch'): + sub_dataset.setup_epoch() + sub_dataset.len = None + elif hasattr(dataset, 'setup_epoch'): + dataset.setup_epoch() + dataset.len = None + elif hasattr(dataloader.dataset, 'setup_epoch'): + dataloader.dataset.setup_epoch() + dataloader.dataset.len = None + elif hasattr(dataloader.dataset, 'datasets'): + dataloader.dataset.len = None + for sub_dataset in dataloader.dataset.datasets: + if hasattr(sub_dataset, 'setup_epoch'): + sub_dataset.setup_epoch() + sub_dataset.len = None + +def get_dataloader_datasets(dataloader: DataLoader): + # hacky but needed because of different types of datasets and dataloaders + if isinstance(dataloader.dataset, list): + datasets = [] + for dataset in dataloader.dataset: + if 
hasattr(dataset, 'datasets'): + for sub_dataset in dataset.datasets: + datasets.append(sub_dataset) + else: + datasets.append(dataset) + return datasets + elif hasattr(dataloader.dataset, 'datasets'): + return dataloader.dataset.datasets + else: + return [dataloader.dataset] diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..34239f4066a2dd885e27ff33e50e26391c17322d --- /dev/null +++ b/toolkit/data_transfer_object/data_loader.py @@ -0,0 +1,252 @@ +import os +import weakref +from _weakref import ReferenceType +from typing import TYPE_CHECKING, List, Union +import torch +import random + +from PIL import Image +from PIL.ImageOps import exif_transpose + +from toolkit import image_utils +from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ + ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \ + UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin + + +if TYPE_CHECKING: + from toolkit.config_modules import DatasetConfig + from toolkit.stable_diffusion_model import StableDiffusion + +printed_messages = [] + + +def print_once(msg): + global printed_messages + if msg not in printed_messages: + print(msg) + printed_messages.append(msg) + + +class FileItemDTO( + LatentCachingFileItemDTOMixin, + CaptionProcessingDTOMixin, + ImageProcessingDTOMixin, + ControlFileItemDTOMixin, + ClipImageFileItemDTOMixin, + MaskFileItemDTOMixin, + AugmentationFileItemDTOMixin, + UnconditionalFileItemDTOMixin, + PoiFileItemDTOMixin, + ArgBreakMixin, +): + def __init__(self, *args, **kwargs): + self.path = kwargs.get('path', '') + self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + size_database = kwargs.get('size_database', {}) + dataset_root = kwargs.get('dataset_root', None) + if dataset_root is not None: + # remove dataset root from path + file_key = self.path.replace(dataset_root, '') + else: + file_key = os.path.basename(self.path) + if file_key in size_database: + w, h = size_database[file_key] + else: + # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now. + # process width and height + # try: + # w, h = image_utils.get_image_size(self.path) + # except image_utils.UnknownImageFormat: + # print_once(f'Warning: Some images in the dataset cannot be fast read. 
' + \ + # f'This process is faster for png, jpeg') + img = exif_transpose(Image.open(self.path)) + w, h = img.size + size_database[file_key] = (w, h) + self.width: int = w + self.height: int = h + self.dataloader_transforms = kwargs.get('dataloader_transforms', None) + super().__init__(*args, **kwargs) + + # self.caption_path: str = kwargs.get('caption_path', None) + self.raw_caption: str = kwargs.get('raw_caption', None) + # we scale first, then crop + self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale)) + self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale)) + # crop values are from scaled size + self.crop_x: int = kwargs.get('crop_x', 0) + self.crop_y: int = kwargs.get('crop_y', 0) + self.crop_width: int = kwargs.get('crop_width', self.scale_to_width) + self.crop_height: int = kwargs.get('crop_height', self.scale_to_height) + self.flip_x: bool = kwargs.get('flip_x', False) + self.flip_y: bool = kwargs.get('flip_x', False) + self.augments: List[str] = self.dataset_config.augments + self.loss_multiplier: float = self.dataset_config.loss_multiplier + + self.network_weight: float = self.dataset_config.network_weight + self.is_reg = self.dataset_config.is_reg + self.tensor: Union[torch.Tensor, None] = None + + def cleanup(self): + self.tensor = None + self.cleanup_latent() + self.cleanup_control() + self.cleanup_clip_image() + self.cleanup_mask() + self.cleanup_unconditional() + + +class DataLoaderBatchDTO: + def __init__(self, **kwargs): + try: + self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None) + is_latents_cached = self.file_items[0].is_latent_cached + self.tensor: Union[torch.Tensor, None] = None + self.latents: Union[torch.Tensor, None] = None + self.control_tensor: Union[torch.Tensor, None] = None + self.clip_image_tensor: Union[torch.Tensor, None] = None + self.mask_tensor: Union[torch.Tensor, None] = None + self.unaugmented_tensor: Union[torch.Tensor, None] = None + self.unconditional_tensor: Union[torch.Tensor, None] = None + self.unconditional_latents: Union[torch.Tensor, None] = None + self.clip_image_embeds: Union[List[dict], None] = None + self.clip_image_embeds_unconditional: Union[List[dict], None] = None + self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code + self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None + if not is_latents_cached: + # only return a tensor if latents are not cached + self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) + # if we have encoded latents, we concatenate them + self.latents: Union[torch.Tensor, None] = None + if is_latents_cached: + self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) + self.control_tensor: Union[torch.Tensor, None] = None + # if self.file_items[0].control_tensor is not None: + # if any have a control tensor, we concatenate them + if any([x.control_tensor is not None for x in self.file_items]): + # find one to use as a base + base_control_tensor = None + for x in self.file_items: + if x.control_tensor is not None: + base_control_tensor = x.control_tensor + break + control_tensors = [] + for x in self.file_items: + if x.control_tensor is None: + control_tensors.append(torch.zeros_like(base_control_tensor)) + else: + control_tensors.append(x.control_tensor) + self.control_tensor = 
torch.cat([x.unsqueeze(0) for x in control_tensors]) + + self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items] + + if any([x.clip_image_tensor is not None for x in self.file_items]): + # find one to use as a base + base_clip_image_tensor = None + for x in self.file_items: + if x.clip_image_tensor is not None: + base_clip_image_tensor = x.clip_image_tensor + break + clip_image_tensors = [] + for x in self.file_items: + if x.clip_image_tensor is None: + clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor)) + else: + clip_image_tensors.append(x.clip_image_tensor) + self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors]) + + if any([x.mask_tensor is not None for x in self.file_items]): + # find one to use as a base + base_mask_tensor = None + for x in self.file_items: + if x.mask_tensor is not None: + base_mask_tensor = x.mask_tensor + break + mask_tensors = [] + for x in self.file_items: + if x.mask_tensor is None: + mask_tensors.append(torch.zeros_like(base_mask_tensor)) + else: + mask_tensors.append(x.mask_tensor) + self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors]) + + # add unaugmented tensors for ones with augments + if any([x.unaugmented_tensor is not None for x in self.file_items]): + # find one to use as a base + base_unaugmented_tensor = None + for x in self.file_items: + if x.unaugmented_tensor is not None: + base_unaugmented_tensor = x.unaugmented_tensor + break + unaugmented_tensor = [] + for x in self.file_items: + if x.unaugmented_tensor is None: + unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor)) + else: + unaugmented_tensor.append(x.unaugmented_tensor) + self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor]) + + # add unconditional tensors + if any([x.unconditional_tensor is not None for x in self.file_items]): + # find one to use as a base + base_unconditional_tensor = None + for x in self.file_items: + if x.unaugmented_tensor is not None: + base_unconditional_tensor = x.unconditional_tensor + break + unconditional_tensor = [] + for x in self.file_items: + if x.unconditional_tensor is None: + unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor)) + else: + unconditional_tensor.append(x.unconditional_tensor) + self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor]) + + if any([x.clip_image_embeds is not None for x in self.file_items]): + self.clip_image_embeds = [] + for x in self.file_items: + if x.clip_image_embeds is not None: + self.clip_image_embeds.append(x.clip_image_embeds) + else: + raise Exception("clip_image_embeds is None for some file items") + + if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]): + self.clip_image_embeds_unconditional = [] + for x in self.file_items: + if x.clip_image_embeds_unconditional is not None: + self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional) + else: + raise Exception("clip_image_embeds_unconditional is None for some file items") + + except Exception as e: + print(e) + raise e + + def get_is_reg_list(self): + return [x.is_reg for x in self.file_items] + + def get_network_weight_list(self): + return [x.network_weight for x in self.file_items] + + def get_caption_list( + self, + trigger=None, + to_replace_list=None, + add_if_not_present=True + ): + return [x.caption for x in self.file_items] + + def get_caption_short_list( + self, + trigger=None, + to_replace_list=None, + add_if_not_present=True + ): 
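+        # Mirrors get_caption_list above; the trigger / to_replace_list / add_if_not_present parameters are accepted for interface parity but are not used by this method.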
+ return [x.caption_short for x in self.file_items] + + def cleanup(self): + del self.latents + del self.tensor + del self.control_tensor + for file_item in self.file_items: + file_item.cleanup() diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..1bba44316952858c0cbf870cb70ec85db4c8fb99 --- /dev/null +++ b/toolkit/dataloader_mixins.py @@ -0,0 +1,1630 @@ +import base64 +import glob +import hashlib +import json +import math +import os +import random +from collections import OrderedDict +from typing import TYPE_CHECKING, List, Dict, Union + +import cv2 +import numpy as np +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor + +from toolkit.basic import flush, value_map +from toolkit.buckets import get_bucket_for_image_size, get_resolution +from toolkit.metadata import get_meta_for_safetensors +from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible +from toolkit.prompt_utils import inject_trigger_into_prompt +from torchvision import transforms +from PIL import Image, ImageFilter, ImageOps +from PIL.ImageOps import exif_transpose +import albumentations as A + +from toolkit.train_tools import get_torch_dtype + +if TYPE_CHECKING: + from toolkit.data_loader import AiToolkitDataset + from toolkit.data_transfer_object.data_loader import FileItemDTO + from toolkit.stable_diffusion_model import StableDiffusion + +# def get_associated_caption_from_img_path(img_path): +# https://demo.albumentations.ai/ +class Augments: + def __init__(self, **kwargs): + self.method_name = kwargs.get('method', None) + self.params = kwargs.get('params', {}) + + # convert kwargs enums for cv2 + for key, value in self.params.items(): + if isinstance(value, str): + # split the string + split_string = value.split('.') + if len(split_string) == 2 and split_string[0] == 'cv2': + if hasattr(cv2, split_string[1]): + self.params[key] = getattr(cv2, split_string[1].upper()) + else: + raise ValueError(f"invalid cv2 enum: {split_string[1]}") + + +transforms_dict = { + 'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03), + 'RandomEqualize': transforms.RandomEqualize(p=0.2), +} + +caption_ext_list = ['txt', 'json', 'caption'] + + +def standardize_images(images): + """ + Standardize the given batch of images using the specified mean and std. + Expects values of 0 - 1 + + Args: + images (torch.Tensor): A batch of images in the shape of (N, C, H, W), + where N is the number of images, C is the number of channels, + H is the height, and W is the width. + + Returns: + torch.Tensor: Standardized images. 
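+    Note: the mean/std values below are the standard CLIP image normalization constants.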
+ """ + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # Define the normalization transform + normalize = transforms.Normalize(mean=mean, std=std) + + # Apply normalization to each image in the batch + standardized_images = torch.stack([normalize(img) for img in images]) + + return standardized_images + +def clean_caption(caption): + # remove any newlines + caption = caption.replace('\n', ', ') + # remove new lines for all operating systems + caption = caption.replace('\r', ', ') + caption_split = caption.split(',') + # remove empty strings + caption_split = [p.strip() for p in caption_split if p.strip()] + # join back together + caption = ', '.join(caption_split) + return caption + + +class CaptionMixin: + def get_caption_item(self: 'AiToolkitDataset', index): + if not hasattr(self, 'caption_type'): + raise Exception('caption_type not found on class instance') + if not hasattr(self, 'file_list'): + raise Exception('file_list not found on class instance') + img_path_or_tuple = self.file_list[index] + if isinstance(img_path_or_tuple, tuple): + img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path + # check if either has a prompt file + path_no_ext = os.path.splitext(img_path)[0] + prompt_path = None + for ext in caption_ext_list: + prompt_path = path_no_ext + '.' + ext + if os.path.exists(prompt_path): + break + else: + img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path + # see if prompt file exists + path_no_ext = os.path.splitext(img_path)[0] + prompt_path = None + for ext in caption_ext_list: + prompt_path = path_no_ext + '.' + ext + if os.path.exists(prompt_path): + break + + # allow folders to have a default prompt + default_prompt_path = os.path.join(os.path.dirname(img_path), 'default.txt') + + if os.path.exists(prompt_path): + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + # check if is json + if prompt_path.endswith('.json'): + prompt = json.loads(prompt) + if 'caption' in prompt: + prompt = prompt['caption'] + + prompt = clean_caption(prompt) + elif os.path.exists(default_prompt_path): + with open(default_prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + prompt = clean_caption(prompt) + else: + prompt = '' + # get default_prompt if it exists on the class instance + if hasattr(self, 'default_prompt'): + prompt = self.default_prompt + if hasattr(self, 'default_caption'): + prompt = self.default_caption + + # handle replacements + replacement_list = self.dataset_config.replacements if isinstance(self.dataset_config.replacements, list) else [] + for replacement in replacement_list: + from_string, to_string = replacement.split('|') + prompt = prompt.replace(from_string, to_string) + + return prompt + + +if TYPE_CHECKING: + from toolkit.config_modules import DatasetConfig + from toolkit.data_transfer_object.data_loader import FileItemDTO + + +class Bucket: + def __init__(self, width: int, height: int): + self.width = width + self.height = height + self.file_list_idx: List[int] = [] + + +class BucketsMixin: + def __init__(self): + self.buckets: Dict[str, Bucket] = {} + self.batch_indices: List[List[int]] = [] + + def build_batch_indices(self: 'AiToolkitDataset'): + self.batch_indices = [] + for key, bucket in self.buckets.items(): + for start_idx in range(0, len(bucket.file_list_idx), self.batch_size): + end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx)) + batch = 
bucket.file_list_idx[start_idx:end_idx] + self.batch_indices.append(batch) + + def shuffle_buckets(self: 'AiToolkitDataset'): + for key, bucket in self.buckets.items(): + random.shuffle(bucket.file_list_idx) + + def setup_buckets(self: 'AiToolkitDataset', quiet=False): + if not hasattr(self, 'file_list'): + raise Exception(f'file_list not found on class instance {self.__class__.__name__}') + if not hasattr(self, 'dataset_config'): + raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}') + + if self.epoch_num > 0 and self.dataset_config.poi is None: + # no need to rebuild buckets for now + # todo handle random cropping for buckets + return + self.buckets = {} # clear it + + config: 'DatasetConfig' = self.dataset_config + resolution = config.resolution + bucket_tolerance = config.bucket_tolerance + file_list: List['FileItemDTO'] = self.file_list + + # for file_item in enumerate(file_list): + for idx, file_item in enumerate(file_list): + file_item: 'FileItemDTO' = file_item + width = int(file_item.width * file_item.dataset_config.scale) + height = int(file_item.height * file_item.dataset_config.scale) + + did_process_poi = False + if file_item.has_point_of_interest: + # Attempt to process the poi if we can. It wont process if the image is smaller than the resolution + did_process_poi = file_item.setup_poi_bucket() + if self.dataset_config.square_crop: + # we scale first so smallest size matches resolution + scale_factor_x = resolution / width + scale_factor_y = resolution / height + scale_factor = max(scale_factor_x, scale_factor_y) + file_item.scale_to_width = math.ceil(width * scale_factor) + file_item.scale_to_height = math.ceil(height * scale_factor) + file_item.crop_width = resolution + file_item.crop_height = resolution + if width > height: + file_item.crop_x = int(file_item.scale_to_width / 2 - resolution / 2) + file_item.crop_y = 0 + else: + file_item.crop_x = 0 + file_item.crop_y = int(file_item.scale_to_height / 2 - resolution / 2) + elif not did_process_poi: + bucket_resolution = get_bucket_for_image_size( + width, height, + resolution=resolution, + divisibility=bucket_tolerance + ) + + # Calculate scale factors for width and height + width_scale_factor = bucket_resolution["width"] / width + height_scale_factor = bucket_resolution["height"] / height + + # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution + max_scale_factor = max(width_scale_factor, height_scale_factor) + + # round up + file_item.scale_to_width = int(math.ceil(width * max_scale_factor)) + file_item.scale_to_height = int(math.ceil(height * max_scale_factor)) + + file_item.crop_height = bucket_resolution["height"] + file_item.crop_width = bucket_resolution["width"] + + new_width = bucket_resolution["width"] + new_height = bucket_resolution["height"] + + if self.dataset_config.random_crop: + # random crop + crop_x = random.randint(0, file_item.scale_to_width - new_width) + crop_y = random.randint(0, file_item.scale_to_height - new_height) + file_item.crop_x = crop_x + file_item.crop_y = crop_y + else: + # do central crop + file_item.crop_x = int((file_item.scale_to_width - new_width) / 2) + file_item.crop_y = int((file_item.scale_to_height - new_height) / 2) + + if file_item.crop_y < 0 or file_item.crop_x < 0: + print('debug') + + # check if bucket exists, if not, create it + bucket_key = f'{file_item.crop_width}x{file_item.crop_height}' + if bucket_key not in self.buckets: + self.buckets[bucket_key] = Bucket(file_item.crop_width, 
file_item.crop_height) + self.buckets[bucket_key].file_list_idx.append(idx) + + # print the buckets + self.shuffle_buckets() + self.build_batch_indices() + if not quiet: + print(f'Bucket sizes for {self.dataset_path}:') + for key, bucket in self.buckets.items(): + print(f'{key}: {len(bucket.file_list_idx)} files') + print(f'{len(self.buckets)} buckets made') + + +class CaptionProcessingDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.raw_caption: str = None + self.raw_caption_short: str = None + self.caption: str = None + self.caption_short: str = None + + dataset_config: DatasetConfig = kwargs.get('dataset_config', None) + self.extra_values: List[float] = dataset_config.extra_values + + # todo allow for loading from sd-scripts style dict + def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]): + if self.raw_caption is not None: + # we already loaded it + pass + elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]: + self.raw_caption = caption_dict[self.path]["caption"] + if 'caption_short' in caption_dict[self.path]: + self.raw_caption_short = caption_dict[self.path]["caption_short"] + else: + # see if prompt file exists + path_no_ext = os.path.splitext(self.path)[0] + prompt_ext = self.dataset_config.caption_ext + prompt_path = f"{path_no_ext}.{prompt_ext}" + short_caption = None + + if os.path.exists(prompt_path): + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + short_caption = None + if prompt_path.endswith('.json'): + # replace any line endings with commas for \n \r \r\n + prompt = prompt.replace('\r\n', ' ') + prompt = prompt.replace('\n', ' ') + prompt = prompt.replace('\r', ' ') + + prompt_json = json.loads(prompt) + if 'caption' in prompt_json: + prompt = prompt_json['caption'] + if 'caption_short' in prompt_json: + short_caption = prompt_json['caption_short'] + + if 'extra_values' in prompt_json: + self.extra_values = prompt_json['extra_values'] + + prompt = clean_caption(prompt) + if short_caption is not None: + short_caption = clean_caption(short_caption) + else: + prompt = '' + if self.dataset_config.default_caption is not None: + prompt = self.dataset_config.default_caption + + if short_caption is None: + short_caption = self.dataset_config.default_caption + self.raw_caption = prompt + self.raw_caption_short = short_caption + + self.caption = self.get_caption() + if self.raw_caption_short is not None: + self.caption_short = self.get_caption(short_caption=True) + + def get_caption( + self: 'FileItemDTO', + trigger=None, + to_replace_list=None, + add_if_not_present=False, + short_caption=False + ): + if short_caption: + raw_caption = self.raw_caption_short + else: + raw_caption = self.raw_caption + if raw_caption is None: + raw_caption = '' + # handle dropout + if self.dataset_config.caption_dropout_rate > 0 and not short_caption: + # get a random float form 0 to 1 + rand = random.random() + if rand < self.dataset_config.caption_dropout_rate: + # drop the caption + return '' + + # get tokens + token_list = raw_caption.split(',') + # trim whitespace + token_list = [x.strip() for x in token_list] + # remove empty strings + token_list = [x for x in token_list if x] + + # handle token dropout + if self.dataset_config.token_dropout_rate > 0 and not short_caption: + new_token_list = [] + keep_tokens: int = self.dataset_config.keep_tokens + for idx, token in enumerate(token_list): + if idx < 
keep_tokens: + new_token_list.append(token) + elif self.dataset_config.token_dropout_rate >= 1.0: + # drop the token + pass + else: + # get a random float form 0 to 1 + rand = random.random() + if rand > self.dataset_config.token_dropout_rate: + # keep the token + new_token_list.append(token) + token_list = new_token_list + + if self.dataset_config.shuffle_tokens: + random.shuffle(token_list) + + # join back together + caption = ', '.join(token_list) + # caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present) + + if self.dataset_config.random_triggers: + num_triggers = self.dataset_config.random_triggers_max + if num_triggers > 1: + num_triggers = random.randint(0, num_triggers) + + if num_triggers > 0: + triggers = random.sample(self.dataset_config.random_triggers, num_triggers) + caption = caption + ', ' + ', '.join(triggers) + # add random triggers + # for i in range(num_triggers): + # # fastest method + # trigger = self.dataset_config.random_triggers[int(random.random() * (len(self.dataset_config.random_triggers)))] + # caption = caption + ', ' + trigger + + if self.dataset_config.shuffle_tokens: + # shuffle again + token_list = caption.split(',') + # trim whitespace + token_list = [x.strip() for x in token_list] + # remove empty strings + token_list = [x for x in token_list if x] + random.shuffle(token_list) + caption = ', '.join(token_list) + + return caption + + +class ImageProcessingDTOMixin: + def load_and_process_image( + self: 'FileItemDTO', + transform: Union[None, transforms.Compose], + only_load_latents=False + ): + # if we are caching latents, just do that + if self.is_latent_cached: + self.get_latent() + if self.has_control_image: + self.load_control_image() + if self.has_clip_image: + self.load_clip_image() + if self.has_mask_image: + self.load_mask_image() + if self.has_unconditional: + self.load_unconditional_image() + return + try: + img = Image.open(self.path) + img = exif_transpose(img) + except Exception as e: + print(f"Error: {e}") + print(f"Error loading image: {self.path}") + + if self.use_alpha_as_mask: + # we do this to make sure it does not replace the alpha with another color + # we want the image just without the alpha channel + np_img = np.array(img) + # strip off alpha + np_img = np_img[:, :, :3] + img = Image.fromarray(np_img) + + img = img.convert('RGB') + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + print( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + print( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height + if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height: + # todo look into this. 
This still happens sometimes + print('size mismatch') + img = img.crop(( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height + )) + + # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + else: + # Downscale the source image first + # TODO this is nto right + img = img.resize( + (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)), + Image.BICUBIC) + min_img_size = min(img.size) + if self.dataset_config.random_crop: + if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution: + if min_img_size < self.dataset_config.resolution: + print( + f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}") + scale_size = self.dataset_config.resolution + else: + scale_size = random.randint(self.dataset_config.resolution, int(min_img_size)) + scaler = scale_size / min_img_size + scale_width = int((img.width + 5) * scaler) + scale_height = int((img.height + 5) * scaler) + img = img.resize((scale_width, scale_height), Image.BICUBIC) + img = transforms.RandomCrop(self.dataset_config.resolution)(img) + else: + img = transforms.CenterCrop(min_img_size)(img) + img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC) + + if self.augments is not None and len(self.augments) > 0: + # do augmentations + for augment in self.augments: + if augment in transforms_dict: + img = transforms_dict[augment](img) + + if self.has_augmentations: + # augmentations handles transforms + img = self.augment_image(img, transform=transform) + elif transform: + img = transform(img) + + self.tensor = img + if not only_load_latents: + if self.has_control_image: + self.load_control_image() + if self.has_clip_image: + self.load_clip_image() + if self.has_mask_image: + self.load_mask_image() + if self.has_unconditional: + self.load_unconditional_image() + + +class ControlFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_control_image = False + self.control_path: Union[str, None] = None + self.control_tensor: Union[torch.Tensor, None] = None + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.full_size_control_images = False + if dataset_config.control_path is not None: + # find the control image path + control_path = dataset_config.control_path + self.full_size_control_images = dataset_config.full_size_control_images + # we are using control images + img_path = kwargs.get('path', None) + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)): + self.control_path = os.path.join(control_path, file_name_no_ext + ext) + self.has_control_image = True + break + + def load_control_image(self: 'FileItemDTO'): + try: + img = Image.open(self.control_path).convert('RGB') + img = exif_transpose(img) + except Exception as e: + print(f"Error: {e}") + print(f"Error loading image: {self.control_path}") + + if self.full_size_control_images: + # we just scale them to 512x512: + w, h = img.size + img = img.resize((512, 512), Image.BICUBIC) + + else: + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, 
file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + # crop + img = img.crop(( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height + )) + else: + raise Exception("Control images not supported for non-bucket datasets") + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + if self.aug_replay_spatial_transforms: + self.control_tensor = self.augment_spatial_control(img, transform=transform) + else: + self.control_tensor = transform(img) + + def cleanup_control(self: 'FileItemDTO'): + self.control_tensor = None + + +class ClipImageFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_clip_image = False + self.clip_image_path: Union[str, None] = None + self.clip_image_tensor: Union[torch.Tensor, None] = None + self.clip_image_embeds: Union[dict, None] = None + self.clip_image_embeds_unconditional: Union[dict, None] = None + self.has_clip_augmentations = False + self.clip_image_aug_transform: Union[None, A.Compose] = None + self.clip_image_processor: Union[None, CLIPImageProcessor] = None + self.clip_image_encoder_path: Union[str, None] = None + self.is_caching_clip_vision_to_disk = False + self.is_vision_clip_cached = False + self.clip_vision_is_quad = False + self.clip_vision_load_device = 'cpu' + self.clip_vision_unconditional_paths: Union[List[str], None] = None + self._clip_vision_embeddings_path: Union[str, None] = None + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder: + # copy the clip image processor so the dataloader can do it + sd = kwargs.get('sd', None) + if hasattr(sd.adapter, 'clip_image_processor'): + self.clip_image_processor = sd.adapter.clip_image_processor + if dataset_config.clip_image_path is not None: + # find the control image path + clip_image_path = dataset_config.clip_image_path + # we are using control images + img_path = kwargs.get('path', None) + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(clip_image_path, file_name_no_ext + ext)): + self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext) + self.has_clip_image = True + break + self.build_clip_imag_augmentation_transform() + + if dataset_config.clip_image_from_same_folder: + # assume we have one. We will pull it on load. 
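+            # Note: with clip_image_from_same_folder, no fixed clip_image_path is stored here; the reference image is resolved lazily in get_new_clip_image_path() by sampling another image from the same folder at load time.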
+ self.has_clip_image = True + self.build_clip_imag_augmentation_transform() + + def build_clip_imag_augmentation_transform(self: 'FileItemDTO'): + if self.dataset_config.clip_image_augmentations is not None and len(self.dataset_config.clip_image_augmentations) > 0: + self.has_clip_augmentations = True + augmentations = [Augments(**aug) for aug in self.dataset_config.clip_image_augmentations] + + if self.dataset_config.clip_image_shuffle_augmentations: + random.shuffle(augmentations) + + augmentation_list = [] + for aug in augmentations: + # make sure method name is valid + assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}" + # get the method + method = getattr(A, aug.method_name) + # add the method to the list + augmentation_list.append(method(**aug.params)) + + self.clip_image_aug_transform = A.Compose(augmentation_list) + + def augment_clip_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ): + if self.dataset_config.clip_image_shuffle_augmentations: + self.build_clip_imag_augmentation_transform() + + open_cv_image = np.array(img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + if self.clip_vision_is_quad: + # image is in a 2x2 gris. split, run augs, and recombine + # split + img1, img2 = np.hsplit(open_cv_image, 2) + img1_1, img1_2 = np.vsplit(img1, 2) + img2_1, img2_2 = np.vsplit(img2, 2) + # apply augmentations + img1_1 = self.clip_image_aug_transform(image=img1_1)["image"] + img1_2 = self.clip_image_aug_transform(image=img1_2)["image"] + img2_1 = self.clip_image_aug_transform(image=img2_1)["image"] + img2_2 = self.clip_image_aug_transform(image=img2_2)["image"] + # recombine + augmented = np.vstack((np.hstack((img1_1, img1_2)), np.hstack((img2_1, img2_2)))) + + else: + # apply augmentations + augmented = self.clip_image_aug_transform(image=open_cv_image)["image"] + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented) + + return augmented_tensor + + def get_clip_vision_info_dict(self: 'FileItemDTO'): + item = OrderedDict([ + ("image_encoder_path", self.clip_image_encoder_path), + ("filename", os.path.basename(self.clip_image_path)), + ("is_quad", self.clip_vision_is_quad) + ]) + # when adding items, do it after so we dont change old latents + if self.flip_x: + item["flip_x"] = True + if self.flip_y: + item["flip_y"] = True + return item + def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False): + if self._clip_vision_embeddings_path is not None and not recalculate: + return self._clip_vision_embeddings_path + else: + # we store latents in a folder in same path as image called _latent_cache + img_dir = os.path.dirname(self.clip_image_path) + latent_dir = os.path.join(img_dir, '_clip_vision_cache') + hash_dict = self.get_clip_vision_info_dict() + filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0] + # get base64 hash of md5 checksum of hash_dict + hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8') + hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii') + hash_str = hash_str.replace('=', '') + self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors') + + return self._clip_vision_embeddings_path + + def get_new_clip_image_path(self: 'FileItemDTO'): + 
if self.dataset_config.clip_image_from_same_folder: + # randomly grab an image path from the same folder + pool_folder = os.path.dirname(self.path) + # find all images in the folder + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + img_files = [] + for ext in img_ext_list: + img_files += glob.glob(os.path.join(pool_folder, f'*{ext}')) + # remove the current image if len is greater than 1 + if len(img_files) > 1: + img_files.remove(self.path) + # randomly grab one + return random.choice(img_files) + else: + return self.clip_image_path + + def load_clip_image(self: 'FileItemDTO'): + is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible) or \ + isinstance(self.clip_image_processor, SiglipImageProcessor) + if self.is_vision_clip_cached: + self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path()) + + # get a random unconditional image + if self.clip_vision_unconditional_paths is not None: + unconditional_path = random.choice(self.clip_vision_unconditional_paths) + self.clip_image_embeds_unconditional = load_file(unconditional_path) + + return + clip_image_path = self.get_new_clip_image_path() + try: + img = Image.open(clip_image_path).convert('RGB') + img = exif_transpose(img) + except Exception as e: + # make a random noise image + img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution)) + print(f"Error: {e}") + print(f"Error loading image: {clip_image_path}") + + img = img.convert('RGB') + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if is_dynamic_size_and_aspect: + pass # let the image processor handle it + elif img.width != img.height: + min_size = min(img.width, img.height) + if self.dataset_config.square_crop: + # center crop to a square + img = transforms.CenterCrop(min_size)(img) + else: + # image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data + # resize to the smallest dimension + img = img.resize((min_size, min_size), Image.BICUBIC) + + if self.has_clip_augmentations: + self.clip_image_tensor = self.augment_clip_image(img, transform=None) + else: + self.clip_image_tensor = transforms.ToTensor()(img) + + # random crop + # if self.dataset_config.clip_image_random_crop: + # # crop up to 20% on all sides. 
Keep is square + # crop_percent = random.randint(0, 20) / 100 + # crop_width = int(self.clip_image_tensor.shape[2] * crop_percent) + # crop_height = int(self.clip_image_tensor.shape[1] * crop_percent) + # crop_left = random.randint(0, crop_width) + # crop_top = random.randint(0, crop_height) + # crop_right = self.clip_image_tensor.shape[2] - crop_width - crop_left + # crop_bottom = self.clip_image_tensor.shape[1] - crop_height - crop_top + # if len(self.clip_image_tensor.shape) == 3: + # self.clip_image_tensor = self.clip_image_tensor[:, crop_top:-crop_bottom, crop_left:-crop_right] + # elif len(self.clip_image_tensor.shape) == 4: + # self.clip_image_tensor = self.clip_image_tensor[:, :, crop_top:-crop_bottom, crop_left:-crop_right] + + if self.clip_image_processor is not None: + # run it + tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16) + clip_out = self.clip_image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + self.clip_image_tensor = clip_out.squeeze(0).clone().detach() + + def cleanup_clip_image(self: 'FileItemDTO'): + self.clip_image_tensor = None + self.clip_image_embeds = None + + + + +class AugmentationFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_augmentations = False + self.unaugmented_tensor: Union[torch.Tensor, None] = None + # self.augmentations: Union[None, List[Augments]] = None + self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.aug_transform: Union[None, A.Compose] = None + self.aug_replay_spatial_transforms = None + self.build_augmentation_transform() + + def build_augmentation_transform(self: 'FileItemDTO'): + if self.dataset_config.augmentations is not None and len(self.dataset_config.augmentations) > 0: + self.has_augmentations = True + augmentations = [Augments(**aug) for aug in self.dataset_config.augmentations] + + if self.dataset_config.shuffle_augmentations: + random.shuffle(augmentations) + + augmentation_list = [] + for aug in augmentations: + # make sure method name is valid + assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}" + # get the method + method = getattr(A, aug.method_name) + # add the method to the list + augmentation_list.append(method(**aug.params)) + + # add additional targets so we can augment the control image + self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'}) + + def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ): + + # rebuild each time if shuffle + if self.dataset_config.shuffle_augmentations: + self.build_augmentation_transform() + + # save the original tensor + self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img) + + open_cv_image = np.array(img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + # apply augmentations + transformed = self.aug_transform(image=open_cv_image) + augmented = transformed["image"] + + # save just the spatial transforms for controls and masks + augmented_params = transformed["replay"] + spatial_transforms = ['Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop', + 'ElasticTransform', 'GridDistortion', 'OpticalDistortion'] + # only store the spatial transforms + augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in 
spatial_transforms] + + if self.dataset_config.replay_transforms: + self.aug_replay_spatial_transforms = augmented_params + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented) + + return augmented_tensor + + # augment control images spatially consistent with transforms done to the main image + def augment_spatial_control(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose] ): + if self.aug_replay_spatial_transforms is None: + # no transforms + return transform(img) + + # save colorspace to convert back to + colorspace = img.mode + + # convert to rgb + img = img.convert('RGB') + + open_cv_image = np.array(img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + # Replay transforms + transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image) + augmented = transformed["image"] + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + # convert back to original colorspace + augmented = augmented.convert(colorspace) + + augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented) + return augmented_tensor + + def cleanup_control(self: 'FileItemDTO'): + self.unaugmented_tensor = None + + +class MaskFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_mask_image = False + self.mask_path: Union[str, None] = None + self.mask_tensor: Union[torch.Tensor, None] = None + self.use_alpha_as_mask: bool = False + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.mask_min_value = dataset_config.mask_min_value + if dataset_config.alpha_mask: + self.use_alpha_as_mask = True + self.mask_path = kwargs.get('path', None) + self.has_mask_image = True + elif dataset_config.mask_path is not None: + # find the control image path + mask_path = dataset_config.mask_path if dataset_config.mask_path is not None else dataset_config.alpha_mask + # we are using control images + img_path = kwargs.get('path', None) + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(mask_path, file_name_no_ext + ext)): + self.mask_path = os.path.join(mask_path, file_name_no_ext + ext) + self.has_mask_image = True + break + + def load_mask_image(self: 'FileItemDTO'): + try: + img = Image.open(self.mask_path) + img = exif_transpose(img) + except Exception as e: + print(f"Error: {e}") + print(f"Error loading image: {self.mask_path}") + + if self.use_alpha_as_mask: + # pipeline expectws an rgb image so we need to put alpha in all channels + np_img = np.array(img) + np_img[:, :, :3] = np_img[:, :, 3:] + + np_img = np_img[:, :, :3] + img = Image.fromarray(np_img) + + img = img.convert('RGB') + if self.dataset_config.invert_mask: + img = ImageOps.invert(img) + w, h = img.size + fix_size = False + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + fix_size = True + 
elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + fix_size = True + + if fix_size: + # swap all the sizes + self.scale_to_width, self.scale_to_height = self.scale_to_height, self.scale_to_width + self.crop_width, self.crop_height = self.crop_height, self.crop_width + self.crop_x, self.crop_y = self.crop_y, self.crop_x + + + + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + # randomly apply a blur up to 0.5% of the size of the min (width, height) + min_size = min(img.width, img.height) + blur_radius = int(min_size * random.random() * 0.005) + img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # make grayscale + img = img.convert('L') + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + # crop + img = img.crop(( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height + )) + else: + raise Exception("Mask images not supported for non-bucket datasets") + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + if self.aug_replay_spatial_transforms: + self.mask_tensor = self.augment_spatial_control(img, transform=transform) + else: + self.mask_tensor = transform(img) + self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0) + # convert to grayscale + + def cleanup_mask(self: 'FileItemDTO'): + self.mask_tensor = None + + +class UnconditionalFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_unconditional = False + self.unconditional_path: Union[str, None] = None + self.unconditional_tensor: Union[torch.Tensor, None] = None + self.unconditional_latent: Union[torch.Tensor, None] = None + self.unconditional_transforms = self.dataloader_transforms + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + + if dataset_config.unconditional_path is not None: + # we are using control images + img_path = kwargs.get('path', None) + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)): + self.unconditional_path = os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext) + self.has_unconditional = True + break + + def load_unconditional_image(self: 'FileItemDTO'): + try: + img = Image.open(self.unconditional_path) + img = exif_transpose(img) + except Exception as e: + print(f"Error: {e}") + print(f"Error loading image: {self.mask_path}") + + img = img.convert('RGB') + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, 
file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + # crop + img = img.crop(( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height + )) + else: + raise Exception("Unconditional images are not supported for non-bucket datasets") + + if self.aug_replay_spatial_transforms: + self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms) + else: + self.unconditional_tensor = self.unconditional_transforms(img) + + def cleanup_unconditional(self: 'FileItemDTO'): + self.unconditional_tensor = None + self.unconditional_latent = None + + +class PoiFileItemDTOMixin: + # Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject + # items in the poi will always be inside the image when random cropping + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + # poi is a name of the box point of interest in the caption json file + dataset_config = kwargs.get('dataset_config', None) + path = kwargs.get('path', None) + self.poi: Union[str, None] = dataset_config.poi + self.has_point_of_interest = self.poi is not None + self.poi_x: Union[int, None] = None + self.poi_y: Union[int, None] = None + self.poi_width: Union[int, None] = None + self.poi_height: Union[int, None] = None + + if self.poi is not None: + # make sure latent caching is off + if dataset_config.cache_latents or dataset_config.cache_latents_to_disk: + raise Exception( + f"Error: poi is not supported when caching latents. Please set cache_latents and cache_latents_to_disk to False in the dataset config" + ) + # make sure we are loading through json + if dataset_config.caption_ext != 'json': + raise Exception( + f"Error: poi is only supported when using json captions. 
Please set caption_ext to json in the dataset config" + ) + self.poi = self.poi.strip() + # get the caption path + file_path_no_ext = os.path.splitext(path)[0] + caption_path = file_path_no_ext + '.json' + if not os.path.exists(caption_path): + raise Exception(f"Error: caption file not found for poi: {caption_path}") + with open(caption_path, 'r', encoding='utf-8') as f: + json_data = json.load(f) + if 'poi' not in json_data: + print(f"Warning: poi not found in caption file: {caption_path}") + if self.poi not in json_data['poi']: + print(f"Warning: poi not found in caption file: {caption_path}") + # poi has, x, y, width, height + # do full image if no poi + self.poi_x = 0 + self.poi_y = 0 + self.poi_width = self.width + self.poi_height = self.height + try: + if self.poi in json_data['poi']: + poi = json_data['poi'][self.poi] + self.poi_x = int(poi['x']) + self.poi_y = int(poi['y']) + self.poi_width = int(poi['width']) + self.poi_height = int(poi['height']) + except Exception as e: + pass + + # handle flipping + if kwargs.get('flip_x', False): + # flip the poi + self.poi_x = self.width - self.poi_x - self.poi_width + if kwargs.get('flip_y', False): + # flip the poi + self.poi_y = self.height - self.poi_y - self.poi_height + + def setup_poi_bucket(self: 'FileItemDTO'): + initial_width = int(self.width * self.dataset_config.scale) + initial_height = int(self.height * self.dataset_config.scale) + # we are using poi, so we need to calculate the bucket based on the poi + + # if img resolution is less than dataset resolution, just return and let the normal bucketing happen + img_resolution = get_resolution(initial_width, initial_height) + if img_resolution <= self.dataset_config.resolution: + return False # will trigger normal bucketing + + bucket_tolerance = self.dataset_config.bucket_tolerance + poi_x = int(self.poi_x * self.dataset_config.scale) + poi_y = int(self.poi_y * self.dataset_config.scale) + poi_width = int(self.poi_width * self.dataset_config.scale) + poi_height = int(self.poi_height * self.dataset_config.scale) + + # loop to keep expanding until we are at the proper resolution. This is not ideal, we can probably handle it better + num_loops = 0 + while True: + # crop left + if poi_x > 0: + poi_x = random.randint(0, poi_x) + else: + poi_x = 0 + + # crop right + cr_min = poi_x + poi_width + if cr_min < initial_width: + crop_right = random.randint(poi_x + poi_width, initial_width) + else: + crop_right = initial_width + + poi_width = crop_right - poi_x + + if poi_y > 0: + poi_y = random.randint(0, poi_y) + else: + poi_y = 0 + + if poi_y + poi_height < initial_height: + crop_bottom = random.randint(poi_y + poi_height, initial_height) + else: + crop_bottom = initial_height + + poi_height = crop_bottom - poi_y + try: + # now we have our random crop, but it may be smaller than resolution. Check and expand if needed + current_resolution = get_resolution(poi_width, poi_height) + except Exception as e: + print(f"Error: {e}") + print(f"Error getting resolution: {self.path}") + raise e + return False + if current_resolution >= self.dataset_config.resolution: + # We can break now + break + else: + num_loops += 1 + if num_loops > 100: + print( + f"Warning: poi bucketing looped too many times. This should not happen. 
Please report this issue.") + return False + + new_width = poi_width + new_height = poi_height + + bucket_resolution = get_bucket_for_image_size( + new_width, new_height, + resolution=self.dataset_config.resolution, + divisibility=bucket_tolerance + ) + + width_scale_factor = bucket_resolution["width"] / new_width + height_scale_factor = bucket_resolution["height"] / new_height + # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution + max_scale_factor = max(width_scale_factor, height_scale_factor) + + self.scale_to_width = math.ceil(initial_width * max_scale_factor) + self.scale_to_height = math.ceil(initial_height * max_scale_factor) + self.crop_width = bucket_resolution['width'] + self.crop_height = bucket_resolution['height'] + self.crop_x = int(poi_x * max_scale_factor) + self.crop_y = int(poi_y * max_scale_factor) + + if self.scale_to_width < self.crop_x + self.crop_width or self.scale_to_height < self.crop_y + self.crop_height: + # todo look into this. This still happens sometimes + print('size mismatch') + + return True + + +class ArgBreakMixin: + # just stops super calls form hitting object + def __init__(self, *args, **kwargs): + pass + + +class LatentCachingFileItemDTOMixin: + def __init__(self, *args, **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self._encoded_latent: Union[torch.Tensor, None] = None + self._latent_path: Union[str, None] = None + self.is_latent_cached = False + self.is_caching_to_disk = False + self.is_caching_to_memory = False + self.latent_load_device = 'cpu' + # sd1 or sdxl or others + self.latent_space_version = 'sd1' + # todo, increment this if we change the latent format to invalidate cache + self.latent_version = 1 + + def get_latent_info_dict(self: 'FileItemDTO'): + item = OrderedDict([ + ("filename", os.path.basename(self.path)), + ("scale_to_width", self.scale_to_width), + ("scale_to_height", self.scale_to_height), + ("crop_x", self.crop_x), + ("crop_y", self.crop_y), + ("crop_width", self.crop_width), + ("crop_height", self.crop_height), + ("latent_space_version", self.latent_space_version), + ("latent_version", self.latent_version), + ]) + # when adding items, do it after so we dont change old latents + if self.flip_x: + item["flip_x"] = True + if self.flip_y: + item["flip_y"] = True + return item + + def get_latent_path(self: 'FileItemDTO', recalculate=False): + if self._latent_path is not None and not recalculate: + return self._latent_path + else: + # we store latents in a folder in same path as image called _latent_cache + img_dir = os.path.dirname(self.path) + latent_dir = os.path.join(img_dir, '_latent_cache') + hash_dict = self.get_latent_info_dict() + filename_no_ext = os.path.splitext(os.path.basename(self.path))[0] + # get base64 hash of md5 checksum of hash_dict + hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8') + hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii') + hash_str = hash_str.replace('=', '') + self._latent_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors') + + return self._latent_path + + def cleanup_latent(self): + if self._encoded_latent is not None: + if not self.is_caching_to_memory: + # we are caching on disk, don't save in memory + self._encoded_latent = None + else: + # move it back to cpu + self._encoded_latent = self._encoded_latent.to('cpu') + + def get_latent(self, device=None): + if not self.is_latent_cached: + return None 
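+        # Latent is cached: reuse the in-memory copy if present, otherwise lazily load the safetensors file written during latent caching (always loaded to CPU here).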
+ if self._encoded_latent is None: + # load it from disk + state_dict = load_file( + self.get_latent_path(), + # device=device if device is not None else self.latent_load_device + device='cpu' + ) + self._encoded_latent = state_dict['latent'] + return self._encoded_latent + + +class LatentCachingMixin: + def __init__(self: 'AiToolkitDataset', **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(**kwargs) + self.latent_cache = {} + + def cache_latents_all_latents(self: 'AiToolkitDataset'): + print(f"Caching latents for {self.dataset_path}") + # cache all latents to disk + to_disk = self.is_caching_latents_to_disk + to_memory = self.is_caching_latents_to_memory + + if to_disk: + print(" - Saving latents to disk") + if to_memory: + print(" - Keeping latents in memory") + # move sd items to cpu except for vae + self.sd.set_device_state_preset('cache_latents') + + # use tqdm to show progress + i = 0 + for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): + # set latent space version + if self.sd.model_config.latent_space_version is not None: + file_item.latent_space_version = self.sd.model_config.latent_space_version + elif self.sd.is_xl: + file_item.latent_space_version = 'sdxl' + elif self.sd.is_v3: + file_item.latent_space_version = 'sd3' + elif self.sd.is_auraflow: + file_item.latent_space_version = 'sdxl' + elif self.sd.is_flux: + file_item.latent_space_version = 'flux1' + elif self.sd.model_config.is_pixart_sigma: + file_item.latent_space_version = 'sdxl' + else: + file_item.latent_space_version = 'sd1' + file_item.is_caching_to_disk = to_disk + file_item.is_caching_to_memory = to_memory + file_item.latent_load_device = self.sd.device + + latent_path = file_item.get_latent_path(recalculate=True) + # check if it is saved to disk already + if os.path.exists(latent_path): + if to_memory: + # load it into memory + state_dict = load_file(latent_path, device='cpu') + file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) + else: + # not saved to disk, calculate + # load the image first + file_item.load_and_process_image(self.transform, only_load_latents=True) + dtype = self.sd.torch_dtype + device = self.sd.device_torch + # add batch dimension + try: + imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) + latent = self.sd.encode_images(imgs).squeeze(0) + except Exception as e: + print(f"Error processing image: {file_item.path}") + print(f"Error: {str(e)}") + raise e + # save_latent + if to_disk: + state_dict = OrderedDict([ + ('latent', latent.clone().detach().cpu()), + ]) + # metadata + meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) + os.makedirs(os.path.dirname(latent_path), exist_ok=True) + save_file(state_dict, latent_path, metadata=meta) + + if to_memory: + # keep it in memory + file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) + + del imgs + del latent + del file_item.tensor + + # flush(garbage_collect=False) + file_item.is_latent_cached = True + i += 1 + # flush every 100 + # if i % 100 == 0: + # flush() + + # restore device state + self.sd.restore_device_state() + + +class CLIPCachingMixin: + def __init__(self: 'AiToolkitDataset', **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(**kwargs) + self.clip_vision_num_unconditional_cache = 20 + self.clip_vision_unconditional_cache = [] + + def cache_clip_vision_to_disk(self: 'AiToolkitDataset'): + if not 
self.is_caching_clip_vision_to_disk: + return + with torch.no_grad(): + print(f"Caching clip vision for {self.dataset_path}") + + print(" - Saving clip to disk") + # move sd items to cpu except for vae + self.sd.set_device_state_preset('cache_clip') + + # make sure the adapter has attributes + if self.sd.adapter is None: + raise Exception("Error: must have an adapter to cache clip vision to disk") + + clip_image_processor: CLIPImageProcessor = None + if hasattr(self.sd.adapter, 'clip_image_processor'): + clip_image_processor = self.sd.adapter.clip_image_processor + + if clip_image_processor is None: + raise Exception("Error: must have a clip image processor to cache clip vision to disk") + + vision_encoder: CLIPVisionModelWithProjection = None + if hasattr(self.sd.adapter, 'image_encoder'): + vision_encoder = self.sd.adapter.image_encoder + if hasattr(self.sd.adapter, 'vision_encoder'): + vision_encoder = self.sd.adapter.vision_encoder + + if vision_encoder is None: + raise Exception("Error: must have a vision encoder to cache clip vision to disk") + + # move vision encoder to device + vision_encoder.to(self.sd.device) + + is_quad = self.sd.adapter.config.quad_image + image_encoder_path = self.sd.adapter.config.image_encoder_path + + dtype = self.sd.torch_dtype + device = self.sd.device_torch + if hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero: + # just to do this, we did :) + # need more samples as it is random noise + self.clip_vision_num_unconditional_cache = self.clip_vision_num_unconditional_cache + else: + # only need one since it doesnt change + self.clip_vision_num_unconditional_cache = 1 + + # cache unconditionals + print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk") + clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache') + + unconditional_paths = [] + + is_noise_zero = hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero + + for i in range(self.clip_vision_num_unconditional_cache): + hash_dict = OrderedDict([ + ("image_encoder_path", image_encoder_path), + ("is_quad", is_quad), + ("is_noise_zero", is_noise_zero), + ]) + # get base64 hash of md5 checksum of hash_dict + hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8') + hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii') + hash_str = hash_str.replace('=', '') + + uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{i}.safetensors') + if os.path.exists(uncond_path): + # skip it + unconditional_paths.append(uncond_path) + continue + + # generate a random image + img_shape = (1, 3, self.sd.adapter.input_size, self.sd.adapter.input_size) + if is_noise_zero: + tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32) + else: + tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32) + clip_image = clip_image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + + if is_quad: + # split the 4x4 grid and stack on batch + ci1, ci2 = clip_image.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach() + + clip_output = vision_encoder( + clip_image.to(device, dtype=dtype), + output_hidden_states=True + ) + # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states'] + state_dict = OrderedDict([ + ('image_embeds', 
clip_output.image_embeds.clone().detach().cpu()), + ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()), + ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()), + ]) + + os.makedirs(os.path.dirname(uncond_path), exist_ok=True) + save_file(state_dict, uncond_path) + unconditional_paths.append(uncond_path) + + self.clip_vision_unconditional_cache = unconditional_paths + + # use tqdm to show progress + i = 0 + for file_item in tqdm(self.file_list, desc=f'Caching clip vision to disk'): + file_item.is_caching_clip_vision_to_disk = True + file_item.clip_vision_load_device = self.sd.device + file_item.clip_vision_is_quad = is_quad + file_item.clip_image_encoder_path = image_encoder_path + file_item.clip_vision_unconditional_paths = unconditional_paths + if file_item.has_clip_augmentations: + raise Exception("Error: clip vision caching is not supported with clip augmentations") + + embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True) + # check if it is saved to disk already + if not os.path.exists(embedding_path): + # load the image first + file_item.load_clip_image() + # add batch dimension + clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype) + + if is_quad: + # split the 4x4 grid and stack on batch + ci1, ci2 = clip_image.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach() + + clip_output = vision_encoder( + clip_image.to(device, dtype=dtype), + output_hidden_states=True + ) + + # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states'] + state_dict = OrderedDict([ + ('image_embeds', clip_output.image_embeds.clone().detach().cpu()), + ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()), + ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()), + ]) + # metadata + meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict()) + os.makedirs(os.path.dirname(embedding_path), exist_ok=True) + save_file(state_dict, embedding_path, metadata=meta) + + del clip_image + del clip_output + del file_item.clip_image_tensor + + # flush(garbage_collect=False) + file_item.is_vision_clip_cached = True + i += 1 + # flush every 100 + # if i % 100 == 0: + # flush() + + # restore device state + self.sd.restore_device_state() diff --git a/toolkit/dequantize.py b/toolkit/dequantize.py new file mode 100644 index 0000000000000000000000000000000000000000..54c8ec7b29862efa11b7fc3c9dc1efc8c1d66423 --- /dev/null +++ b/toolkit/dequantize.py @@ -0,0 +1,88 @@ + + +from functools import partial +from optimum.quanto.tensor import QTensor +import torch + + +def hacked_state_dict(self, *args, **kwargs): + orig_state_dict = self.orig_state_dict(*args, **kwargs) + new_state_dict = {} + for key, value in orig_state_dict.items(): + if key.endswith("._scale"): + continue + if key.endswith(".input_scale"): + continue + if key.endswith(".output_scale"): + continue + if key.endswith("._data"): + key = key[:-6] + scale = orig_state_dict[key + "._scale"] + # scale is the original dtype + dtype = scale.dtype + scale = scale.float() + value = value.float() + dequantized = value * scale + + # handle input and output scaling if they exist + input_scale = orig_state_dict.get(key + ".input_scale") + + if input_scale is not None: + # make sure the tensor is 1.0 + if input_scale.item() != 1.0: + raise ValueError("Input scale is not 1.0, cannot dequantize") + + 
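+            # Note: `input_scale` / `output_scale` appear to be activation
+            # scales from quantization calibration; only the weight scale is
+            # folded into `value * scale` above, so a non-unit activation
+            # scale cannot be expressed in the plain float weight rebuilt
+            # here, and we fail loudly rather than silently dropping it.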
+            output_scale = orig_state_dict.get(key + ".output_scale")
+
+            if output_scale is not None:
+                # make sure the tensor is 1.0
+                if output_scale.item() != 1.0:
+                    raise ValueError("Output scale is not 1.0, cannot dequantize")
+
+            new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
+        else:
+            new_state_dict[key] = value
+    return new_state_dict
+
+# hacks the state dict so we can dequantize before saving
+def patch_dequantization_on_save(model):
+    model.orig_state_dict = model.state_dict
+    model.state_dict = partial(hacked_state_dict, model)
+
+
+def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
+    """
+    Convert a quantized parameter back to a regular Parameter with floating point values.
+
+    Args:
+        module: The module containing the parameter to unquantize
+        param_name: Name of the parameter to unquantize (e.g., 'weight', 'bias')
+
+    Returns:
+        bool: True if parameter was unquantized, False if it was already unquantized
+    """
+
+    # Check if the parameter exists
+    if not hasattr(module, param_name):
+        raise AttributeError(f"Module has no parameter named '{param_name}'")
+
+    param = getattr(module, param_name)
+
+    # If it's not a parameter or not quantized, nothing to do
+    if not isinstance(param, torch.nn.Parameter):
+        raise TypeError(f"'{param_name}' is not a Parameter")
+    if not isinstance(param, QTensor):
+        return False
+
+    # Convert to float tensor while preserving device and requires_grad
+    with torch.no_grad():
+        float_tensor = param.float()
+        new_param = torch.nn.Parameter(
+            float_tensor,
+            requires_grad=param.requires_grad
+        )
+
+    # Replace the parameter
+    setattr(module, param_name, new_param)
+
+    return True
\ No newline at end of file
diff --git a/toolkit/ema.py b/toolkit/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b3a7ea0974e37d783cb75e93d50659105bbb49
--- /dev/null
+++ b/toolkit/ema.py
@@ -0,0 +1,346 @@
+from __future__ import division
+from __future__ import unicode_literals
+
+from typing import Iterable, Optional
+import weakref
+import copy
+import contextlib
+from toolkit.optimizers.optimizer_utils import copy_stochastic
+
+import torch
+
+
+# Partially based on:
+# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
+class ExponentialMovingAverage:
+    """
+    Maintains (exponential) moving average of a set of parameters.
+
+    Args:
+        parameters: Iterable of `torch.nn.Parameter` (typically from
+            `model.parameters()`).
+            Note that EMA is computed on *all* provided parameters,
+            regardless of whether or not they have `requires_grad = True`;
+            this allows a single EMA object to be consistently used even
+            if which parameters are trainable changes step to step.
+
+            If you do not want some parameters to be included in the EMA,
+            simply do not pass them to the object in the first place.
+            For example:
+
+                ExponentialMovingAverage(
+                    parameters=[p for p in model.parameters() if p.requires_grad],
+                    decay=0.9
+                )
+
+            will ignore parameters that do not require grad.
+
+        decay: The exponential decay.
+
+        use_num_updates: Whether to use number of updates when computing
+            averages.
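+
+    A minimal usage sketch (``model``, ``loader``, ``optimizer`` and
+    ``evaluate`` are placeholders for your own training objects, not part
+    of this module):
+
+        ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
+        for batch in loader:
+            loss = model(batch).mean()
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            ema.update()  # track the parameters after every optimizer step
+        with ema.average_parameters():
+            evaluate(model)  # temporarily runs with the EMA weights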
+ """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter] = None, + decay: float = 0.995, + use_num_updates: bool = False, + # feeds back the decat to the parameter + use_feedback: bool = False, + param_multiplier: float = 1.0 + ): + if parameters is None: + raise ValueError("parameters must be provided") + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.decay = decay + self.num_updates = 0 if use_num_updates else None + self.use_feedback = use_feedback + self.param_multiplier = param_multiplier + parameters = list(parameters) + self.shadow_params = [ + p.clone().detach() + for p in parameters + ] + self.collected_params = None + self._is_train_mode = True + # By maintaining only a weakref to each parameter, + # we maintain the old GC behaviour of ExponentialMovingAverage: + # if the model goes out of scope but the ExponentialMovingAverage + # is kept, no references to the model or its parameters will be + # maintained, and the model will be cleaned up. + self._params_refs = [weakref.ref(p) for p in parameters] + + def _get_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] + ) -> Iterable[torch.nn.Parameter]: + if parameters is None: + parameters = [p() for p in self._params_refs] + if any(p is None for p in parameters): + raise ValueError( + "(One of) the parameters with which this " + "ExponentialMovingAverage " + "was initialized no longer exists (was garbage collected);" + " please either provide `parameters` explicitly or keep " + "the model to which they belong from being garbage " + "collected." + ) + return parameters + else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) + return parameters + + def update( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Update currently maintained parameters. + + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. 
+ """ + parameters = self._get_parameters(parameters) + decay = self.decay + if self.num_updates is not None: + self.num_updates += 1 + decay = min( + decay, + (1 + self.num_updates) / (10 + self.num_updates) + ) + one_minus_decay = 1.0 - decay + with torch.no_grad(): + for s_param, param in zip(self.shadow_params, parameters): + s_param_float = s_param.float() + if s_param.dtype != torch.float32: + s_param_float = s_param_float.to(torch.float32) + param_float = param + if param.dtype != torch.float32: + param_float = param_float.to(torch.float32) + tmp = (s_param_float - param_float) + # tmp will be a new tensor so we can do in-place + tmp.mul_(one_minus_decay) + s_param_float.sub_(tmp) + + update_param = False + if self.use_feedback: + param_float.add_(tmp) + update_param = True + + if self.param_multiplier != 1.0: + param_float.mul_(self.param_multiplier) + update_param = True + + if s_param.dtype != torch.float32: + copy_stochastic(s_param, s_param_float) + + if update_param and param.dtype != torch.float32: + copy_stochastic(param, param_float) + + + def copy_to( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def store( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Save the current parameters for restoring later. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.collected_params = [ + param.clone() + for param in parameters + ] + + def restore( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) + parameters = self._get_parameters(parameters) + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) + + @contextlib.contextmanager + def average_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + + Equivalent to: + + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. 
+ """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + self.restore(parameters) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] + return + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict.""" + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "num_updates": self.num_updates, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params + } + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.num_updates = state_dict["num_updates"] + assert self.num_updates is None or isinstance(self.num_updates, int), \ + "Invalid num_updates" + + self.shadow_params = state_dict["shadow_params"] + assert isinstance(self.shadow_params, list), \ + "shadow_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.shadow_params + ), "shadow_params must all be Tensors" + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len(self.shadow_params), \ + "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." 
+ ) + + def eval(self): + if self._is_train_mode: + with torch.no_grad(): + self.store() + self.copy_to() + self._is_train_mode = False + + def train(self): + if not self._is_train_mode: + with torch.no_grad(): + self.restore() + self._is_train_mode = True diff --git a/toolkit/embedding.py b/toolkit/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..94ba3f2f33bfa023f31da37f12c3ca4a34f0cc21 --- /dev/null +++ b/toolkit/embedding.py @@ -0,0 +1,284 @@ +import json +import os +from collections import OrderedDict + +import safetensors +import torch +from typing import TYPE_CHECKING + +from safetensors.torch import save_file + +from toolkit.metadata import get_meta_for_safetensors + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.config_modules import EmbeddingConfig + + +# this is a frankenstein mix of automatic1111 and my own code + +class Embedding: + def __init__( + self, + sd: 'StableDiffusion', + embed_config: 'EmbeddingConfig', + state_dict: OrderedDict = None, + ): + self.name = embed_config.trigger + self.sd = sd + self.trigger = embed_config.trigger + self.embed_config = embed_config + self.step = 0 + # setup our embedding + # Add the placeholder token in tokenizer + placeholder_tokens = [self.embed_config.trigger] + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, self.embed_config.tokens): + additional_tokens.append(f"{self.embed_config.trigger}_{i}") + placeholder_tokens += additional_tokens + + # handle dual tokenizer + self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer] + self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [ + self.sd.text_encoder] + + self.placeholder_token_ids = [] + self.embedding_tokens = [] + + print(f"Adding {placeholder_tokens} tokens to tokenizer") + print(f"Adding {self.embed_config.tokens} tokens to tokenizer") + + for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list): + num_added_tokens = tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != self.embed_config.tokens: + raise ValueError( + f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different" + f" `placeholder_token` that is not already in the tokenizer. 
Only added {num_added_tokens}" + ) + + # Convert the initializer_token, placeholder_token to ids + init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False) + # if length of token ids is more than number of orm embedding tokens fill with * + if len(init_token_ids) > self.embed_config.tokens: + init_token_ids = init_token_ids[:self.embed_config.tokens] + elif len(init_token_ids) < self.embed_config.tokens: + pad_token_id = tokenizer.encode(["*"], add_special_tokens=False) + init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids)) + + placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False) + self.placeholder_token_ids.append(placeholder_token_ids) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids): + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # replace "[name] with this. on training. This is automatically generated in pipeline on inference + self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))) + + # backup text encoder embeddings + self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list] + + def restore_embeddings(self): + with torch.no_grad(): + # Let's make sure we don't update any embedding weights besides the newly added token + for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list, + self.tokenizer_list, + self.orig_embeds_params, + self.placeholder_token_ids): + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[ min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False + text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds[index_no_updates] + weight = text_encoder.get_input_embeddings().weight + pass + + def get_trainable_params(self): + params = [] + for text_encoder in self.text_encoder_list: + params += text_encoder.get_input_embeddings().parameters() + return params + + def _get_vec(self, text_encoder_idx=0): + # should we get params instead + # create vector from token embeds + token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data + # stack the tokens along batch axis adding that axis + new_vector = torch.stack( + [token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]], + dim=0 + ) + return new_vector + + def _set_vec(self, new_vector, text_encoder_idx=0): + # shape is (1, 768) for SD 1.5 for 1 token + token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data + for i in range(new_vector.shape[0]): + # apply the weights to the placeholder tokens while preserving gradient + token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone() + + # make setter and getter for vec + @property + def vec(self): + return self._get_vec(0) + + @vec.setter + def vec(self, new_vector): + self._set_vec(new_vector, 0) + + @property + def vec2(self): + return self._get_vec(1) + + @vec2.setter + def vec2(self, new_vector): + self._set_vec(new_vector, 1) + + # diffusers automatically expands the token meaning 
test123 becomes test123 test123_1 test123_2 etc + # however, on training we don't use that pipeline, so we have to do it ourselves + def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True): + output_prompt = prompt + embedding_tokens = self.embedding_tokens[0] # shoudl be the same + default_replacements = ["[name]", "[trigger]"] + + replace_with = embedding_tokens if expand_token else self.trigger + if to_replace_list is None: + to_replace_list = default_replacements + else: + to_replace_list += default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances == 0 and add_if_not_present: + # add it to the beginning of the prompt + output_prompt = replace_with + " " + output_prompt + + if num_instances > 1: + print( + f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.") + + return output_prompt + + def state_dict(self): + if self.sd.is_xl: + state_dict = OrderedDict() + state_dict['clip_l'] = self.vec + state_dict['clip_g'] = self.vec2 + else: + state_dict = OrderedDict() + state_dict['emb_params'] = self.vec + + return state_dict + + def save(self, filename): + # todo check to see how to get the vector out of the embedding + + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + # todo get these + "sd_checkpoint": None, + "sd_checkpoint_name": None, + "notes": None, + } + # TODO we do not currently support this. Check how auto is doing it. 
Only safetensors supported sor sdxl + if filename.endswith('.pt'): + torch.save(embedding_data, filename) + elif filename.endswith('.bin'): + torch.save(embedding_data, filename) + elif filename.endswith('.safetensors'): + # save the embedding as a safetensors file + state_dict = self.state_dict() + # add all embedding data (except string_to_param), to metadata + metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"}) + metadata["string_to_param"] = {"*": "emb_params"} + save_meta = get_meta_for_safetensors(metadata, name=self.name) + save_file(state_dict, filename, metadata=save_meta) + + def load_embedding_from_file(self, file_path, device): + # full path + path = os.path.realpath(file_path) + filename = os.path.basename(path) + name, ext = os.path.splitext(filename) + tensors = {} + ext = ext.upper() + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return + + if ext in ['.BIN', '.PT']: + # todo check this + if self.sd.is_xl: + raise Exception("XL not supported yet for bin, pt") + data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + # rebuild the embedding from the safetensors file if it has it + with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + tensors[k] = f.get_tensor(k) + # data = safetensors.torch.load_file(path, device="cpu") + if metadata and 'string_to_param' in metadata and 'emb_params' in tensors: + # our format + def try_json(v): + try: + return json.loads(v) + except: + return v + + data = {k: try_json(v) for k, v in metadata.items()} + data['string_to_param'] = {'*': tensors['emb_params']} + else: + # old format + data = tensors + else: + return + + if self.sd.is_xl: + self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32) + self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32) + if 'step' in data: + self.step = int(data['step']) + else: + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, + '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception( + f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + if 'step' in data: + self.step = int(data['step']) + + self.vec = emb.detach().to(device, dtype=torch.float32) diff --git a/toolkit/esrgan_utils.py b/toolkit/esrgan_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..25a8bfbada1bff84bc6bb1a49149d846c9c8c379 --- /dev/null +++ b/toolkit/esrgan_utils.py @@ -0,0 +1,51 @@ + +to_basicsr_dict = { + 'model.0.weight': 'conv_first.weight', + 'model.0.bias': 'conv_first.bias', + 'model.1.sub.23.weight': 'conv_body.weight', + 'model.1.sub.23.bias': 'conv_body.bias', + 'model.3.weight': 'conv_up1.weight', + 'model.3.bias': 'conv_up1.bias', + 'model.6.weight': 'conv_up2.weight', + 'model.6.bias': 'conv_up2.bias', + 'model.8.weight': 'conv_hr.weight', + 'model.8.bias': 
'conv_hr.bias', + 'model.10.bias': 'conv_last.bias', + 'model.10.weight': 'conv_last.weight', + # 'model.1.sub.0.RDB1.conv1.0.weight': 'body.0.rdb1.conv1.weight' +} + +def convert_state_dict_to_basicsr(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if k in to_basicsr_dict: + new_state_dict[to_basicsr_dict[k]] = v + elif k.startswith('model.1.sub.'): + bsr_name = k.replace('model.1.sub.', 'body.').lower() + bsr_name = bsr_name.replace('.0.weight', '.weight') + bsr_name = bsr_name.replace('.0.bias', '.bias') + new_state_dict[bsr_name] = v + else: + new_state_dict[k] = v + return new_state_dict + + +# just matching a commonly used format +def convert_basicsr_state_dict_to_save_format(state_dict): + new_state_dict = {} + to_basicsr_dict_values = list(to_basicsr_dict.values()) + for k, v in state_dict.items(): + if k in to_basicsr_dict_values: + for key, value in to_basicsr_dict.items(): + if value == k: + new_state_dict[key] = v + + elif k.startswith('body.'): + bsr_name = k.replace('body.', 'model.1.sub.').lower() + bsr_name = bsr_name.replace('rdb', 'RDB') + bsr_name = bsr_name.replace('.weight', '.0.weight') + bsr_name = bsr_name.replace('.bias', '.0.bias') + new_state_dict[bsr_name] = v + else: + new_state_dict[k] = v + return new_state_dict diff --git a/toolkit/extension.py b/toolkit/extension.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1f38e57c7295546bb621c6f3234346f92f73a1 --- /dev/null +++ b/toolkit/extension.py @@ -0,0 +1,57 @@ +import os +import importlib +import pkgutil +from typing import List + +from toolkit.paths import TOOLKIT_ROOT + + +class Extension(object): + """Base class for extensions. + + Extensions are registered with the ExtensionManager, which is + responsible for calling the extension's load() and unload() + methods at the appropriate times. + + """ + + name: str = None + uid: str = None + + @classmethod + def get_process(cls): + # extend in subclass + pass + + +def get_all_extensions() -> List[Extension]: + extension_folders = ['extensions', 'extensions_built_in'] + + # This will hold the classes from all extension modules + all_extension_classes: List[Extension] = [] + + # Iterate over all directories (i.e., packages) in the "extensions" directory + for sub_dir in extension_folders: + extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir) + for (_, name, _) in pkgutil.iter_modules([extensions_dir]): + try: + # Import the module + module = importlib.import_module(f"{sub_dir}.{name}") + # Get the value of the AI_TOOLKIT_EXTENSIONS variable + extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None) + # Check if the value is a list + if isinstance(extensions, list): + # Iterate over the list and add the classes to the main list + all_extension_classes.extend(extensions) + except ImportError as e: + print(f"Failed to import the {name} module. 
Error: {str(e)}") + + return all_extension_classes + + +def get_all_extensions_process_dict(): + all_extensions = get_all_extensions() + process_dict = {} + for extension in all_extensions: + process_dict[extension.uid] = extension.get_process() + return process_dict diff --git a/toolkit/guidance.py b/toolkit/guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf282046c12eba068dac7e918945135b705ef9e --- /dev/null +++ b/toolkit/guidance.py @@ -0,0 +1,693 @@ +import torch +from typing import Literal, Optional + +from toolkit.basic import value_map +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion +from toolkit.train_tools import get_torch_dtype + +GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"] + +DIFFERENTIAL_SCALER = 0.2 + + +# DIFFERENTIAL_SCALER = 0.25 + + +def get_differential_mask( + conditional_latents: torch.Tensor, + unconditional_latents: torch.Tensor, + threshold: float = 0.2, + gradient: bool = False, +): + # make a differential mask + differential_mask = torch.abs(conditional_latents - unconditional_latents) + max_differential = \ + differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + differential_scaler = 1.0 / max_differential + differential_mask = differential_mask * differential_scaler + + if gradient: + # wew need to scale it to 0-1 + # differential_mask = differential_mask - differential_mask.min() + # differential_mask = differential_mask / differential_mask.max() + # add 0.2 threshold to both sides and clip + differential_mask = value_map( + differential_mask, + differential_mask.min(), + differential_mask.max(), + 0 - threshold, + 1 + threshold + ) + differential_mask = torch.clamp(differential_mask, 0.0, 1.0) + else: + + # make everything less than 0.2 be 0.0 and everything else be 1.0 + differential_mask = torch.where( + differential_mask < threshold, + torch.zeros_like(differential_mask), + torch.ones_like(differential_mask) + ) + return differential_mask + + +def get_targeted_polarity_loss( + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + **kwargs +): + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + with torch.no_grad(): + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + # inputs_abs_mean = torch.abs(conditional_latents).mean(dim=[1, 2, 3], keepdim=True) + # noise_abs_mean = torch.abs(noise).mean(dim=[1, 2, 3], keepdim=True) + differential_scaler = DIFFERENTIAL_SCALER + + unconditional_diff = (unconditional_latents - conditional_latents) + unconditional_diff_noise = unconditional_diff * differential_scaler + conditional_diff = (conditional_latents - unconditional_latents) + conditional_diff_noise = conditional_diff * differential_scaler + conditional_diff_noise = conditional_diff_noise.detach().requires_grad_(False) + unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False) + # + baseline_conditional_noisy_latents = sd.add_noise( + conditional_latents, + noise, + timesteps + ).detach() + + baseline_unconditional_noisy_latents = sd.add_noise( + 
unconditional_latents, + noise, + timesteps + ).detach() + + conditional_noise = noise + unconditional_diff_noise + unconditional_noise = noise + conditional_diff_noise + + conditional_noisy_latents = sd.add_noise( + conditional_latents, + conditional_noise, + timesteps + ).detach() + + unconditional_noisy_latents = sd.add_noise( + unconditional_latents, + unconditional_noise, + timesteps + ).detach() + + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + # cat_baseline_noisy_latents = torch.cat( + # [baseline_conditional_noisy_latents, baseline_unconditional_noisy_latents], + # dim=0 + # ) + + # Disable the LoRA network so we can predict parent network knowledge without it + # sd.network.is_active = False + # sd.unet.eval() + + # Predict noise to get a baseline of what the parent network wants to do with the latents + noise. + # This acts as our control to preserve the unaltered parts of the image. + # baseline_prediction = sd.predict_noise( + # latents=cat_baseline_noisy_latents.to(device, dtype=dtype).detach(), + # conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), + # timestep=cat_timesteps, + # guidance_scale=1.0, + # **pred_kwargs # adapter residuals in here + # ).detach() + + # conditional_baseline_prediction, unconditional_baseline_prediction = torch.chunk(baseline_prediction, 2, dim=0) + + # negative_network_weights = [weight * -1.0 for weight in network_weight_list] + # positive_network_weights = [weight * 1.0 for weight in network_weight_list] + # cat_network_weight_list = positive_network_weights + negative_network_weights + + # turn the LoRA network back on. 
+ sd.unet.train() + # sd.network.is_active = True + + # sd.network.multiplier = cat_network_weight_list + + # do our prediction with LoRA active on the scaled guidance latents + prediction = sd.predict_noise( + latents=cat_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + # prediction = prediction - baseline_prediction + + pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) + # pred_pos = pred_pos - conditional_baseline_prediction + # pred_neg = pred_neg - unconditional_baseline_prediction + + pred_loss = torch.nn.functional.mse_loss( + pred_pos.float(), + conditional_noise.float(), + reduction="none" + ) + pred_loss = pred_loss.mean([1, 2, 3]) + + pred_neg_loss = torch.nn.functional.mse_loss( + pred_neg.float(), + unconditional_noise.float(), + reduction="none" + ) + pred_neg_loss = pred_neg_loss.mean([1, 2, 3]) + + loss = pred_loss + pred_neg_loss + + loss = loss.mean() + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + +def get_direct_guidance_loss( + noisy_latents: torch.Tensor, + conditional_embeds: 'PromptEmbeds', + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + unconditional_embeds: Optional[PromptEmbeds] = None, + mask_multiplier=None, + prior_pred=None, + **kwargs +): + with torch.no_grad(): + # Perform targeted guidance (working title) + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + + + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + conditional_noisy_latents = sd.add_noise( + conditional_latents, + # target_noise, + noise, + timesteps + ).detach() + + unconditional_noisy_latents = sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() + # turn the LoRA network back on. 
+ sd.unet.train() + # sd.network.is_active = True + + # sd.network.multiplier = network_weight_list + # do our prediction with LoRA active on the scaled guidance latents + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(device, dtype=dtype).detach() + unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds]) + + prediction = sd.predict_noise( + latents=torch.cat([unconditional_noisy_latents, conditional_noisy_latents]).to(device, dtype=dtype).detach(), + conditional_embeddings=concat_prompt_embeds([conditional_embeds,conditional_embeds]).to(device, dtype=dtype).detach(), + unconditional_embeddings=unconditional_embeds, + timestep=torch.cat([timesteps, timesteps]), + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0) + + guidance_scale = 1.1 + guidance_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + guidance_loss = torch.nn.functional.mse_loss( + guidance_pred.float(), + noise.detach().float(), + reduction="none" + ) + if mask_multiplier is not None: + guidance_loss = guidance_loss * mask_multiplier + + guidance_loss = guidance_loss.mean([1, 2, 3]) + + guidance_loss = guidance_loss.mean() + + # loss = guidance_loss + masked_noise_loss + loss = guidance_loss + + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + + +# targeted +def get_targeted_guidance_loss( + noisy_latents: torch.Tensor, + conditional_embeds: 'PromptEmbeds', + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + **kwargs +): + with torch.no_grad(): + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + # Encode the unconditional image into latents + unconditional_noisy_latents = sd.noise_scheduler.add_noise( + unconditional_latents, + noise, + timesteps + ) + conditional_noisy_latents = sd.noise_scheduler.add_noise( + conditional_latents, + noise, + timesteps + ) + + # was_network_active = self.network.is_active + sd.network.is_active = False + sd.unet.eval() + + target_differential = unconditional_latents - conditional_latents + # scale our loss by the differential scaler + target_differential_abs = target_differential.abs() + target_differential_abs_min = \ + target_differential_abs.min(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + target_differential_abs_max = \ + target_differential_abs.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + + min_guidance = 1.0 + max_guidance = 2.0 + + differential_scaler = value_map( + target_differential_abs, + target_differential_abs_min, + target_differential_abs_max, + min_guidance, + max_guidance + ).detach() + + + # With LoRA network bypassed, predict noise to get a baseline of what the network + # wants to do with the latents + noise. Pass our target latents here for the input. 
+ target_unconditional = sd.predict_noise( + latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ).detach() + prior_prediction_loss = torch.nn.functional.mse_loss( + target_unconditional.float(), + noise.float(), + reduction="none" + ).detach().clone() + + # turn the LoRA network back on. + sd.unet.train() + sd.network.is_active = True + sd.network.multiplier = network_weight_list + [x + -1.0 for x in network_weight_list] + + # with LoRA active, predict the noise with the scaled differential latents added. This will allow us + # the opportunity to predict the differential + noise that was added to the latents. + prediction = sd.predict_noise( + latents=torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0).to(device, dtype=dtype).detach(), + conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(), + timestep=torch.cat([timesteps, timesteps], dim=0), + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + prediction_conditional, prediction_unconditional = torch.chunk(prediction, 2, dim=0) + + conditional_loss = torch.nn.functional.mse_loss( + prediction_conditional.float(), + noise.float(), + reduction="none" + ) + + unconditional_loss = torch.nn.functional.mse_loss( + prediction_unconditional.float(), + noise.float(), + reduction="none" + ) + + positive_loss = torch.abs( + conditional_loss.float() - prior_prediction_loss.float(), + ) + # scale our loss by the differential scaler + positive_loss = positive_loss * differential_scaler + + positive_loss = positive_loss.mean([1, 2, 3]) + + polar_loss = torch.abs( + conditional_loss.float() - unconditional_loss.float(), + ).mean([1, 2, 3]) + + + positive_loss = positive_loss.mean() + polar_loss.mean() + + + positive_loss.backward() + # loss = positive_loss.detach() + negative_loss.detach() + loss = positive_loss.detach() + + # add a grad so other backward does not fail + loss.requires_grad_(True) + + # restore network + sd.network.multiplier = network_weight_list + + return loss + +def get_guided_loss_polarity( + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + scaler=None, + **kwargs +): + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + with torch.no_grad(): + dtype = get_torch_dtype(dtype) + noise = noise.to(device, dtype=dtype).detach() + + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + target_pos = noise + target_neg = noise + + if sd.is_flow_matching: + # set the timesteps for flow matching as linear since we will do weighing + sd.noise_scheduler.set_train_timesteps(1000, device, linear=True) + target_pos = (noise - conditional_latents).detach() + target_neg = (noise - unconditional_latents).detach() + + conditional_noisy_latents = sd.add_noise( + conditional_latents, + noise, + timesteps + ).detach() + + unconditional_noisy_latents = sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() + + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, 
conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + + negative_network_weights = [weight * -1.0 for weight in network_weight_list] + positive_network_weights = [weight * 1.0 for weight in network_weight_list] + cat_network_weight_list = positive_network_weights + negative_network_weights + + # turn the LoRA network back on. + sd.unet.train() + sd.network.is_active = True + + sd.network.multiplier = cat_network_weight_list + + # do our prediction with LoRA active on the scaled guidance latents + prediction = sd.predict_noise( + latents=cat_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) + + pred_loss = torch.nn.functional.mse_loss( + pred_pos.float(), + target_pos.float(), + reduction="none" + ) + # pred_loss = pred_loss.mean([1, 2, 3]) + + pred_neg_loss = torch.nn.functional.mse_loss( + pred_neg.float(), + target_neg.float(), + reduction="none" + ) + + loss = pred_loss + pred_neg_loss + + # if sd.is_flow_matching: + # timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach() + # loss = loss * timestep_weight + + + loss = loss.mean([1, 2, 3]) + loss = loss.mean() + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + + + +def get_guided_tnt( + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + prior_pred: torch.Tensor = None, + **kwargs +): + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + with torch.no_grad(): + dtype = get_torch_dtype(dtype) + noise = noise.to(device, dtype=dtype).detach() + + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + conditional_noisy_latents = sd.add_noise( + conditional_latents, + noise, + timesteps + ).detach() + + unconditional_noisy_latents = sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() + + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + + + # turn the LoRA network back on. 
+ sd.unet.train() + if sd.network is not None: + cat_network_weight_list = [weight for weight in network_weight_list * 2] + sd.network.multiplier = cat_network_weight_list + sd.network.is_active = True + + + prediction = sd.predict_noise( + latents=cat_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0) + + this_loss = torch.nn.functional.mse_loss( + this_prediction.float(), + noise.float(), + reduction="none" + ) + + that_loss = torch.nn.functional.mse_loss( + that_prediction.float(), + noise.float(), + reduction="none" + ) + + this_loss = this_loss.mean([1, 2, 3]) + # negative loss on that + that_loss = -that_loss.mean([1, 2, 3]) + + with torch.no_grad(): + # match that loss with this loss so it is not a negative value and same scale + that_loss_scaler = torch.abs(this_loss) / torch.abs(that_loss) + + that_loss = that_loss * that_loss_scaler * 0.01 + + loss = this_loss + that_loss + + loss = loss.mean() + + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + + + +# this processes all guidance losses based on the batch information +def get_guidance_loss( + noisy_latents: torch.Tensor, + conditional_embeds: 'PromptEmbeds', + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + unconditional_embeds: Optional[PromptEmbeds] = None, + mask_multiplier=None, + prior_pred=None, + scaler=None, + **kwargs +): + # TODO add others and process individual batch items separately + guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type + + if guidance_type == "targeted": + assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance" + return get_targeted_guidance_loss( + noisy_latents, + conditional_embeds, + match_adapter_assist, + network_weight_list, + timesteps, + pred_kwargs, + batch, + noise, + sd, + **kwargs + ) + elif guidance_type == "polarity": + assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance" + return get_guided_loss_polarity( + noisy_latents, + conditional_embeds, + match_adapter_assist, + network_weight_list, + timesteps, + pred_kwargs, + batch, + noise, + sd, + scaler=scaler, + **kwargs + ) + elif guidance_type == "tnt": + assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance" + return get_guided_tnt( + noisy_latents, + conditional_embeds, + match_adapter_assist, + network_weight_list, + timesteps, + pred_kwargs, + batch, + noise, + sd, + prior_pred=prior_pred, + **kwargs + ) + + elif guidance_type == "targeted_polarity": + assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance" + return get_targeted_polarity_loss( + noisy_latents, + conditional_embeds, + match_adapter_assist, + network_weight_list, + timesteps, + pred_kwargs, + batch, + noise, + sd, + **kwargs + ) + elif guidance_type == "direct": + return get_direct_guidance_loss( + noisy_latents, + conditional_embeds, + match_adapter_assist, + network_weight_list, + timesteps, + pred_kwargs, + batch, + noise, + sd, + unconditional_embeds=unconditional_embeds, + 
mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + **kwargs + ) + else: + raise NotImplementedError(f"Guidance type {guidance_type} is not implemented") diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9f306e077ee50594f6ca6a4a005025434dcdf7 --- /dev/null +++ b/toolkit/image_utils.py @@ -0,0 +1,516 @@ +# ref https://github.com/scardine/image_size/blob/master/get_image_size.py +import atexit +import collections +import json +import os +import io +import struct +import threading +from typing import TYPE_CHECKING + +import cv2 +import numpy as np +import torch +from diffusers import AutoencoderTiny + +FILE_UNKNOWN = "Sorry, don't know how to get size for this file." + + +class UnknownImageFormat(Exception): + pass + + +types = collections.OrderedDict() +BMP = types['BMP'] = 'BMP' +GIF = types['GIF'] = 'GIF' +ICO = types['ICO'] = 'ICO' +JPEG = types['JPEG'] = 'JPEG' +PNG = types['PNG'] = 'PNG' +TIFF = types['TIFF'] = 'TIFF' + +image_fields = ['path', 'type', 'file_size', 'width', 'height'] + + +class Image(collections.namedtuple('Image', image_fields)): + + def to_str_row(self): + return ("%d\t%d\t%d\t%s\t%s" % ( + self.width, + self.height, + self.file_size, + self.type, + self.path.replace('\t', '\\t'), + )) + + def to_str_row_verbose(self): + return ("%d\t%d\t%d\t%s\t%s\t##%s" % ( + self.width, + self.height, + self.file_size, + self.type, + self.path.replace('\t', '\\t'), + self)) + + def to_str_json(self, indent=None): + return json.dumps(self._asdict(), indent=indent) + + +def get_image_size(file_path): + """ + Return (width, height) for a given img file content - no external + dependencies except the os and struct builtin modules + """ + img = get_image_metadata(file_path) + return (img.width, img.height) + + +def get_image_size_from_bytesio(input, size): + """ + Return (width, height) for a given img file content - no external + dependencies except the os and struct builtin modules + + Args: + input (io.IOBase): io object support read & seek + size (int): size of buffer in byte + """ + img = get_image_metadata_from_bytesio(input, size) + return (img.width, img.height) + + +def get_image_metadata(file_path): + """ + Return an `Image` object for a given img file content - no external + dependencies except the os and struct builtin modules + + Args: + file_path (str): path to an image file + + Returns: + Image: (path, type, file_size, width, height) + """ + size = os.path.getsize(file_path) + + # be explicit with open arguments - we need binary mode + with io.open(file_path, "rb") as input: + return get_image_metadata_from_bytesio(input, size, file_path) + + +def get_image_metadata_from_bytesio(input, size, file_path=None): + """ + Return an `Image` object for a given img file content - no external + dependencies except the os and struct builtin modules + + Args: + input (io.IOBase): io object support read & seek + size (int): size of buffer in byte + file_path (str): path to an image file + + Returns: + Image: (path, type, file_size, width, height) + """ + height = -1 + width = -1 + data = input.read(26) + msg = " raised while trying to decode as JPEG." 
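+    # The branches below sniff the container format from the magic bytes in
+    # `data` (the 26 header bytes read above) and unpack width/height with
+    # `struct`. Rough usage via the wrapper defined earlier (the file name is
+    # only illustrative):
+    #     width, height = get_image_size("some_image.png")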
+ + if (size >= 10) and data[:6] in (b'GIF87a', b'GIF89a'): + # GIFs + imgtype = GIF + w, h = struct.unpack("<HH", data[6:10]) + width = int(w) + height = int(h) + elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n') + and (data[12:16] == b'IHDR')): + # PNGs + imgtype = PNG + w, h = struct.unpack(">LL", data[16:24]) + width = int(w) + height = int(h) + elif (size >= 16) and data.startswith(b'\211PNG\r\n\032\n'): + # older PNGs + imgtype = PNG + w, h = struct.unpack(">LL", data[8:16]) + width = int(w) + height = int(h) + elif (size >= 2) and data.startswith(b'\377\330'): + # JPEG + imgtype = JPEG + input.seek(0) + input.read(2) + b = input.read(1) + try: + while (b and ord(b) != 0xDA): + while (ord(b) != 0xFF): + b = input.read(1) + while (ord(b) == 0xFF): + b = input.read(1) + if (ord(b) >= 0xC0 and ord(b) <= 0xC3): + input.read(3) + h, w = struct.unpack(">HH", input.read(4)) + break + else: + input.read( + int(struct.unpack(">H", input.read(2))[0]) - 2) + b = input.read(1) + width = int(w) + height = int(h) + except struct.error: + raise UnknownImageFormat("StructError" + msg) + except ValueError: + raise UnknownImageFormat("ValueError" + msg) + except Exception as e: + raise UnknownImageFormat(e.__class__.__name__ + msg) + elif (size >= 26) and data.startswith(b'BM'): + # BMP + imgtype = 'BMP' + headersize = struct.unpack("<I", data[14:18])[0] + if headersize == 12: + w, h = struct.unpack("<HH", data[18:22]) + width = int(w) + height = int(h) + elif headersize >= 40: + w, h = struct.unpack("<ii", data[18:26]) + width = int(w) + # as h is negative when stored upside down + height = abs(int(h)) + else: + raise UnknownImageFormat( + "Unkown DIB header size:" + + str(headersize)) + elif (size >= 8) and data[:4] in (b"II\052\000", b"MM\000\052"): + # Standard TIFF, big- or little-endian + # BigTIFF and other different but TIFF-like formats are not + # supported currently + imgtype = TIFF + byteOrder = data[:2] + boChar = ">" if byteOrder == "MM" else "<" + # maps TIFF type id to size (in bytes) + # and python format char for struct + tiffTypes = { + 1: (1, boChar + "B"), # BYTE + 2: (1, boChar + "c"), # ASCII + 3: (2, boChar + "H"), # SHORT + 4: (4, boChar + "L"), # LONG + 5: (8, boChar + "LL"), # RATIONAL + 6: (1, boChar + "b"), # SBYTE + 7: (1, boChar + "c"), # UNDEFINED + 8: (2, boChar + "h"), # SSHORT + 9: (4, boChar + "l"), # SLONG + 10: (8, boChar + "ll"), # SRATIONAL + 11: (4, boChar + "f"), # FLOAT + 12: (8, boChar + "d") # DOUBLE + } + ifdOffset = struct.unpack(boChar + "L", data[4:8])[0] + try: + countSize = 2 + input.seek(ifdOffset) + ec = input.read(countSize) + ifdEntryCount = struct.unpack(boChar + "H", ec)[0] + # 2 bytes: TagId + 2 bytes: type + 4 bytes: count of values + 4 + # bytes: value offset + ifdEntrySize = 12 + for i in range(ifdEntryCount): + entryOffset = ifdOffset + countSize + i * ifdEntrySize + input.seek(entryOffset) + tag = input.read(2) + tag = struct.unpack(boChar + "H", tag)[0] + if (tag == 256 or tag == 257): + # if type indicates that value fits into 4 bytes, value + # offset is not an offset but value itself + type = input.read(2) + type = struct.unpack(boChar + "H", type)[0] + if type not in tiffTypes: + raise UnknownImageFormat( + "Unkown TIFF field type:" + + str(type)) + typeSize = tiffTypes[type][0] + typeChar = tiffTypes[type][1] + input.seek(entryOffset + 8) + value = input.read(typeSize) + value = int(struct.unpack(typeChar, value)[0]) + if tag == 256: + width = value + else: + height = value + if width > -1 and height > -1: + break + except Exception as e: + raise UnknownImageFormat(str(e)) + elif size >= 2: + # see http://en.wikipedia.org/wiki/ICO_(file_format) + imgtype = 'ICO' + input.seek(0) + reserved = input.read(2) + if 0 != struct.unpack("<H", reserved)[0]: + raise UnknownImageFormat(FILE_UNKNOWN) + format = input.read(2) + assert 1 == struct.unpack("<H", format)[0] + num = input.read(2) + num = struct.unpack("<H", num)[0] + if num > 1: + import warnings + warnings.warn("ICO File contains more than one image") + # http://msdn.microsoft.com/en-us/library/ms997538.aspx + w = input.read(1) + h = input.read(1) + width = ord(w) +
height = ord(h) + else: + raise UnknownImageFormat(FILE_UNKNOWN) + + return Image(path=file_path, + type=imgtype, + file_size=size, + width=width, + height=height) + + +import unittest + + +class Test_get_image_size(unittest.TestCase): + data = [{ + 'path': 'lookmanodeps.png', + 'width': 251, + 'height': 208, + 'file_size': 22228, + 'type': 'PNG'}] + + def setUp(self): + pass + + def test_get_image_size_from_bytesio(self): + img = self.data[0] + p = img['path'] + with io.open(p, 'rb') as fp: + b = fp.read() + fp = io.BytesIO(b) + sz = len(b) + output = get_image_size_from_bytesio(fp, sz) + self.assertTrue(output) + self.assertEqual(output, + (img['width'], + img['height'])) + + def test_get_image_metadata_from_bytesio(self): + img = self.data[0] + p = img['path'] + with io.open(p, 'rb') as fp: + b = fp.read() + fp = io.BytesIO(b) + sz = len(b) + output = get_image_metadata_from_bytesio(fp, sz) + self.assertTrue(output) + for field in image_fields: + self.assertEqual(getattr(output, field), None if field == 'path' else img[field]) + + def test_get_image_metadata(self): + img = self.data[0] + output = get_image_metadata(img['path']) + self.assertTrue(output) + for field in image_fields: + self.assertEqual(getattr(output, field), img[field]) + + def test_get_image_metadata__ENOENT_OSError(self): + with self.assertRaises(OSError): + get_image_metadata('THIS_DOES_NOT_EXIST') + + def test_get_image_metadata__not_an_image_UnknownImageFormat(self): + with self.assertRaises(UnknownImageFormat): + get_image_metadata('README.rst') + + def test_get_image_size(self): + img = self.data[0] + output = get_image_size(img['path']) + self.assertTrue(output) + self.assertEqual(output, + (img['width'], + img['height'])) + + def tearDown(self): + pass + + +def main(argv=None): + """ + Print image metadata fields for the given file path. + + Keyword Arguments: + argv (list): commandline arguments (e.g. 
sys.argv[1:]) + Returns: + int: zero for OK + """ + import logging + import optparse + import sys + + prs = optparse.OptionParser( + usage="%prog [-v|--verbose] [--json|--json-indent] []", + description="Print metadata for the given image paths " + "(without image library bindings).") + + prs.add_option('--json', + dest='json', + action='store_true') + prs.add_option('--json-indent', + dest='json_indent', + action='store_true') + + prs.add_option('-v', '--verbose', + dest='verbose', + action='store_true', ) + prs.add_option('-q', '--quiet', + dest='quiet', + action='store_true', ) + prs.add_option('-t', '--test', + dest='run_tests', + action='store_true', ) + + argv = list(argv) if argv is not None else sys.argv[1:] + (opts, args) = prs.parse_args(args=argv) + loglevel = logging.INFO + if opts.verbose: + loglevel = logging.DEBUG + elif opts.quiet: + loglevel = logging.ERROR + logging.basicConfig(level=loglevel) + log = logging.getLogger() + log.debug('argv: %r', argv) + log.debug('opts: %r', opts) + log.debug('args: %r', args) + + if opts.run_tests: + import sys + sys.argv = [sys.argv[0]] + args + import unittest + return unittest.main() + + output_func = Image.to_str_row + if opts.json_indent: + import functools + output_func = functools.partial(Image.to_str_json, indent=2) + elif opts.json: + output_func = Image.to_str_json + elif opts.verbose: + output_func = Image.to_str_row_verbose + + EX_OK = 0 + EX_NOT_OK = 2 + + if len(args) < 1: + prs.print_help() + print('') + prs.error("You must specify one or more paths to image files") + + errors = [] + for path_arg in args: + try: + img = get_image_metadata(path_arg) + print(output_func(img)) + except KeyboardInterrupt: + raise + except OSError as e: + log.error((path_arg, e)) + errors.append((path_arg, e)) + except Exception as e: + log.exception(e) + errors.append((path_arg, e)) + pass + if len(errors): + import pprint + print("ERRORS", file=sys.stderr) + print("======", file=sys.stderr) + print(pprint.pformat(errors, indent=2), file=sys.stderr) + return EX_NOT_OK + return EX_OK + + +is_window_shown = False +display_lock = threading.Lock() +current_img = None +update_event = threading.Event() + +def update_image(img, name): + global current_img + with display_lock: + current_img = (img, name) + update_event.set() + +def display_image_in_thread(): + global is_window_shown + + def display_img(): + global current_img + while True: + update_event.wait() + with display_lock: + if current_img: + img, name = current_img + cv2.imshow(name, img) + current_img = None + update_event.clear() + if cv2.waitKey(1) & 0xFF == 27: # Esc key to stop + cv2.destroyAllWindows() + print('\nESC pressed, stopping') + break + + if not is_window_shown: + is_window_shown = True + threading.Thread(target=display_img, daemon=True).start() + + +def show_img(img, name='AI Toolkit'): + img = np.clip(img, 0, 255).astype(np.uint8) + update_image(img[:, :, ::-1], name) + if not is_window_shown: + display_image_in_thread() + + +def show_tensors(imgs: torch.Tensor, name='AI Toolkit'): + if len(imgs.shape) == 4: + img_list = torch.chunk(imgs, imgs.shape[0], dim=0) + else: + img_list = [imgs] + + img = torch.cat(img_list, dim=3) + img = img / 2 + 0.5 + img_numpy = img.to(torch.float32).detach().cpu().numpy() + img_numpy = np.clip(img_numpy, 0, 1) * 255 + img_numpy = img_numpy.transpose(0, 2, 3, 1) + img_numpy = img_numpy.astype(np.uint8) + + show_img(img_numpy[0], name=name) + + +def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'): + if vae.device 
== 'cpu': + vae.to(latents.device) + latents = latents / vae.config['scaling_factor'] + imgs = vae.decode(latents).sample + show_tensors(imgs, name=name) + + +def on_exit(): + if is_window_shown: + cv2.destroyAllWindows() + + +def reduce_contrast(tensor, factor): + # Ensure factor is between 0 and 1 + factor = max(0, min(factor, 1)) + + # Calculate the mean of the tensor + mean = torch.mean(tensor) + + # Reduce contrast + adjusted_tensor = (tensor - mean) * factor + mean + + # Clip values to ensure they stay within -1 to 1 range + return torch.clamp(adjusted_tensor, -1.0, 1.0) + +atexit.register(on_exit) + +if __name__ == "__main__": + import sys + + sys.exit(main(argv=sys.argv[1:])) diff --git a/toolkit/inversion_utils.py b/toolkit/inversion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51a61d83bd8d62efbfa2c3b99069f7c6bb0f81ca --- /dev/null +++ b/toolkit/inversion_utils.py @@ -0,0 +1,410 @@ +# ref https://huggingface.co/spaces/editing-images/ledits/blob/main/inversion_utils.py + +import torch +import os +from tqdm import tqdm + +from toolkit import train_tools +from toolkit.prompt_utils import PromptEmbeds +from toolkit.stable_diffusion_model import StableDiffusion + + +def mu_tilde(model, xt, x0, timestep): + "mu_tilde(x_t, x_0) DDPM paper eq. 7" + prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps + alpha_prod_t_prev = model.scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod + alpha_t = model.scheduler.alphas[timestep] + beta_t = 1 - alpha_t + alpha_bar = model.scheduler.alphas_cumprod[timestep] + return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + ( + (alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt + + +def sample_xts_from_x0(sd: StableDiffusion, sample: torch.Tensor, num_inference_steps=50): + """ + Samples from P(x_1:T|x_0) + """ + # torch.manual_seed(43256465436) + alpha_bar = sd.noise_scheduler.alphas_cumprod + sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5 + alphas = sd.noise_scheduler.alphas + betas = 1 - alphas + # variance_noise_shape = ( + # num_inference_steps, + # sd.unet.in_channels, + # sd.unet.sample_size, + # sd.unet.sample_size) + variance_noise_shape = list(sample.shape) + variance_noise_shape[0] = num_inference_steps + + timesteps = sd.noise_scheduler.timesteps.to(sd.device) + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xts = torch.zeros(variance_noise_shape).to(sample.device, dtype=torch.float16) + for t in reversed(timesteps): + idx = t_to_idx[int(t)] + xts[idx] = sample * (alpha_bar[t] ** 0.5) + torch.randn_like(sample, dtype=torch.float16) * sqrt_one_minus_alpha_bar[t] + xts = torch.cat([xts, sample], dim=0) + + return xts + + +def encode_text(model, prompts): + text_input = model.tokenizer( + prompts, + padding="max_length", + max_length=model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + with torch.no_grad(): + text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0] + return text_encoding + + +def forward_step(sd: StableDiffusion, model_output, timestep, sample): + next_timestep = min( + sd.noise_scheduler.config['num_train_timesteps'] - 2, + timestep + sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps + ) + + # 2. 
compute alphas, betas + alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep] + # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 5. TODO: simple noising implementation + next_sample = sd.noise_scheduler.add_noise( + pred_original_sample, + model_output, + torch.LongTensor([next_timestep])) + return next_sample + + +def get_variance(sd: StableDiffusion, timestep): # , prev_timestep): + prev_timestep = timestep - sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps + alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + return variance + + +def get_time_ids_from_latents(sd: StableDiffusion, latents: torch.Tensor): + VAE_SCALE_FACTOR = 2 ** (len(sd.vae.config['block_out_channels']) - 1) + if sd.is_xl: + bs, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + + dtype = latents.dtype + # just do it without any cropping nonsense + target_size = (height, width) + original_size = (height, width) + crops_coords_top_left = (0, 0) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(latents.device, dtype=dtype) + + batch_time_ids = torch.cat( + [add_time_ids for _ in range(bs)] + ) + return batch_time_ids + else: + return None + + +def inversion_forward_process( + sd: StableDiffusion, + sample: torch.Tensor, + conditional_embeddings: PromptEmbeds, + unconditional_embeddings: PromptEmbeds, + etas=None, + prog_bar=False, + cfg_scale=3.5, + num_inference_steps=50, eps=None +): + current_num_timesteps = len(sd.noise_scheduler.timesteps) + sd.noise_scheduler.set_timesteps(num_inference_steps, device=sd.device) + + timesteps = sd.noise_scheduler.timesteps.to(sd.device) + # variance_noise_shape = ( + # num_inference_steps, + # sd.unet.in_channels, + # sd.unet.sample_size, + # sd.unet.sample_size + # ) + variance_noise_shape = list(sample.shape) + variance_noise_shape[0] = num_inference_steps + if etas is None or (type(etas) in [int, float] and etas == 0): + eta_is_zero = True + zs = None + else: + eta_is_zero = False + if type(etas) in [int, float]: etas = [etas] * sd.noise_scheduler.num_inference_steps + xts = sample_xts_from_x0(sd, sample, num_inference_steps=num_inference_steps) + alpha_bar = sd.noise_scheduler.alphas_cumprod + zs = torch.zeros(size=variance_noise_shape, device=sd.device, dtype=torch.float16) + + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + noisy_sample = sample + op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps) + + for timestep in op: + idx = t_to_idx[int(timestep)] + # 1. 
predict noise residual + if not eta_is_zero: + noisy_sample = xts[idx][None] + + added_cond_kwargs = {} + + with torch.no_grad(): + text_embeddings = train_tools.concat_prompt_embeddings( + unconditional_embeddings, # negative embedding + conditional_embeddings, # positive embedding + 1, # batch size + ) + if sd.is_xl: + add_time_ids = get_time_ids_from_latents(sd, noisy_sample) + # add extra for cfg + add_time_ids = torch.cat( + [add_time_ids] * 2, dim=0 + ) + + added_cond_kwargs = { + "text_embeds": text_embeddings.pooled_embeds, + "time_ids": add_time_ids, + } + + # double up for cfg + latent_model_input = torch.cat( + [noisy_sample] * 2, dim=0 + ) + + noise_pred = sd.unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings.text_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + # out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=uncond_embedding) + # cond_out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=text_embeddings) + + noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond) + + if eta_is_zero: + # 2. compute more noisy image and set x_t -> x_t+1 + noisy_sample = forward_step(sd, noise_pred, timestep, noisy_sample) + xts = None + + else: + xtm1 = xts[idx + 1][None] + # pred of x0 + pred_original_sample = (noisy_sample - (1 - alpha_bar[timestep]) ** 0.5 * noise_pred) / alpha_bar[ + timestep] ** 0.5 + + # direction to xt + prev_timestep = timestep - sd.noise_scheduler.config[ + 'num_train_timesteps'] // sd.noise_scheduler.num_inference_steps + alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod + + variance = get_variance(sd, timestep) + pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred + + mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5) + zs[idx] = z + + # correction to avoid error accumulation + xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z + xts[idx + 1] = xtm1 + + if not zs is None: + zs[-1] = torch.zeros_like(zs[-1]) + + # restore timesteps + sd.noise_scheduler.set_timesteps(current_num_timesteps, device=sd.device) + + return noisy_sample, zs, xts + + +# +# def inversion_forward_process( +# model, +# sample, +# etas=None, +# prog_bar=False, +# prompt="", +# cfg_scale=3.5, +# num_inference_steps=50, eps=None +# ): +# if not prompt == "": +# text_embeddings = encode_text(model, prompt) +# uncond_embedding = encode_text(model, "") +# timesteps = model.scheduler.timesteps.to(model.device) +# variance_noise_shape = ( +# num_inference_steps, +# model.unet.in_channels, +# model.unet.sample_size, +# model.unet.sample_size) +# if etas is None or (type(etas) in [int, float] and etas == 0): +# eta_is_zero = True +# zs = None +# else: +# eta_is_zero = False +# if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps +# xts = sample_xts_from_x0(model, sample, num_inference_steps=num_inference_steps) +# alpha_bar = model.scheduler.alphas_cumprod +# zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype=torch.float16) +# +# t_to_idx = {int(v): k for k, v in enumerate(timesteps)} +# noisy_sample = sample +# op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps) +# +# for t in op: +# idx = t_to_idx[int(t)] +# # 1. 
predict noise residual +# if not eta_is_zero: +# noisy_sample = xts[idx][None] +# +# with torch.no_grad(): +# out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=uncond_embedding) +# if not prompt == "": +# cond_out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=text_embeddings) +# +# if not prompt == "": +# ## classifier free guidance +# noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample) +# else: +# noise_pred = out.sample +# +# if eta_is_zero: +# # 2. compute more noisy image and set x_t -> x_t+1 +# noisy_sample = forward_step(model, noise_pred, t, noisy_sample) +# +# else: +# xtm1 = xts[idx + 1][None] +# # pred of x0 +# pred_original_sample = (noisy_sample - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5 +# +# # direction to xt +# prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps +# alpha_prod_t_prev = model.scheduler.alphas_cumprod[ +# prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod +# +# variance = get_variance(model, t) +# pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred +# +# mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction +# +# z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5) +# zs[idx] = z +# +# # correction to avoid error accumulation +# xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z +# xts[idx + 1] = xtm1 +# +# if not zs is None: +# zs[-1] = torch.zeros_like(zs[-1]) +# +# return noisy_sample, zs, xts + + +def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None): + # 1. get previous step value (=t-1) + prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps + # 2. compute alphas, betas + alpha_prod_t = model.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = model.scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + # variance = self.scheduler._get_variance(timestep, prev_timestep) + variance = get_variance(model, timestep) # , prev_timestep) + std_dev_t = eta * variance ** (0.5) + # Take care of asymetric reverse process (asyrp) + model_output_direction = model_output + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction + pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + # 8. 
Add noice if eta > 0 + if eta > 0: + if variance_noise is None: + variance_noise = torch.randn(model_output.shape, device=model.device, dtype=torch.float16) + sigma_z = eta * variance ** (0.5) * variance_noise + prev_sample = prev_sample + sigma_z + + return prev_sample + + +def inversion_reverse_process( + model, + xT, + etas=0, + prompts="", + cfg_scales=None, + prog_bar=False, + zs=None, + controller=None, + asyrp=False): + batch_size = len(prompts) + + cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device, dtype=torch.float16) + + text_embeddings = encode_text(model, prompts) + uncond_embedding = encode_text(model, [""] * batch_size) + + if etas is None: etas = 0 + if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps + assert len(etas) == model.scheduler.num_inference_steps + timesteps = model.scheduler.timesteps.to(model.device) + + xt = xT.expand(batch_size, -1, -1, -1) + op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:] + + t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])} + + for t in op: + idx = t_to_idx[int(t)] + ## Unconditional embedding + with torch.no_grad(): + uncond_out = model.unet.forward(xt, timestep=t, + encoder_hidden_states=uncond_embedding) + + ## Conditional embedding + if prompts: + with torch.no_grad(): + cond_out = model.unet.forward(xt, timestep=t, + encoder_hidden_states=text_embeddings) + + z = zs[idx] if not zs is None else None + z = z.expand(batch_size, -1, -1, -1) + if prompts: + ## classifier free guidance + noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample) + else: + noise_pred = uncond_out.sample + # 2. compute less noisy image and set x_t -> x_t-1 + xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z) + if controller is not None: + xt = controller.step_callback(xt) + return xt, zs diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..4821e968728ee1102c826d06de5022713ca82256 --- /dev/null +++ b/toolkit/ip_adapter.py @@ -0,0 +1,1337 @@ +import random + +import torch +import sys + +from PIL import Image +from diffusers import Transformer2DModel +from torch import nn +from torch.nn import Parameter +from torch.nn.modules.module import T +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from toolkit.models.clip_pre_processor import CLIPImagePreProcessor +from toolkit.models.zipper_resampler import ZipperResampler +from toolkit.paths import REPOS_ROOT +from toolkit.saving import load_ip_adapter_model +from toolkit.train_tools import get_torch_dtype +from toolkit.util.inverse_cfg import inverse_classifier_guidance + +sys.path.append(REPOS_ROOT) +from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional +from collections import OrderedDict +from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \ + AttnProcessor2_0 +from ipadapter.ip_adapter.ip_adapter import ImageProjModel +from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler +from toolkit.config_modules import AdapterConfig +from toolkit.prompt_utils import PromptEmbeds +import weakref +from diffusers import FluxTransformer2DModel + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPVisionModel, + 
AutoImageProcessor, + ConvNextModel, + ConvNextV2ForImageClassification, + ConvNextForImageClassification, + ConvNextImageProcessor +) +from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel + +from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification + +from transformers import ViTFeatureExtractor, ViTForImageClassification + +# gradient checkpointing +from torch.utils.checkpoint import checkpoint + +import torch.nn.functional as F + + +class MLPProjModelClipFace(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + self.norm = torch.nn.LayerNorm(id_embeddings_dim) + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens), + ) + # Initialize the last linear layer weights near zero + torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01) + torch.nn.init.zeros_(self.proj[2].bias) + # # Custom initialization for LayerNorm to output near zero + # torch.nn.init.constant_(self.norm.weight, 0.1) # Small weights near zero + # torch.nn.init.zeros_(self.norm.bias) # Bias to zero + + def forward(self, x): + x = self.norm(x) + x = self.proj(x) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return x + + +class CustomIPAttentionProcessor(IPAttnProcessor2_0): + def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, full_token_scaler=False): + super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens) + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.train_scaler = train_scaler + if train_scaler: + if full_token_scaler: + self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999) + else: + self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999) + # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999) + self.ip_scaler.requires_grad_(True) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + is_active = self.adapter_ref().is_active + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if is_active: + # since we are removing tokens, we need to adjust the sequence length + sequence_length = sequence_length - self.num_tokens + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + # 
will be none if disabled + if not is_active: + ip_hidden_states = None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + try: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + except Exception as e: + print(e) + raise e + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # will be none if disabled + if ip_hidden_states is not None: + # apply scaler + if self.train_scaler: + weight = self.ip_scaler + # reshape to (1, self.num_tokens, 1) + weight = weight.view(1, -1, 1) + ip_hidden_states = ip_hidden_states * weight + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + scale = self.scale + hidden_states = hidden_states + scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + # this ensures that the ip_scaler is not changed when we load the model + # def _apply(self, fn): + # if hasattr(self, "ip_scaler"): + # # Overriding the _apply method to prevent the special_parameter from changing dtype + # self.ip_scaler = fn(self.ip_scaler) + # # Temporarily set the special_parameter to None to exclude it from default _apply processing + # ip_scaler = self.ip_scaler + # self.ip_scaler = None + # super(CustomIPAttentionProcessor, self)._apply(fn) + # # Restore the special_parameter after the default _apply processing + # self.ip_scaler = ip_scaler + # return self + # else: + # return super(CustomIPAttentionProcessor, self)._apply(fn) + + +class 
CustomIPFluxAttnProcessor2_0(torch.nn.Module): + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, + full_token_scaler=False): + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.train_scaler = train_scaler + self.num_tokens = num_tokens + if train_scaler: + if full_token_scaler: + self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999) + else: + self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999) + # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999) + self.ip_scaler.requires_grad_(True) + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + is_active = self.adapter_ref().is_active + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
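+ # The text-context Q/K/V computed below are concatenated with the image-stream Q/K/V along + # the sequence dim before the joint attention call; the IP-Adapter image tokens are handled + # separately further down (to_k_ip/to_v_ip) and their attention output is added to the + # result, weighted by self.scale.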
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # begin ip adapter + if not is_active: + ip_hidden_states = None + else: + # get ip hidden states. Should be stored + ip_hidden_states = self.adapter_ref().last_conditional + # add unconditional to front if it exists + if ip_hidden_states.shape[0] * 2 == batch_size: + if self.adapter_ref().last_unconditional is None: + raise ValueError("Unconditional is None but should not be") + ip_hidden_states = torch.cat([self.adapter_ref().last_unconditional, ip_hidden_states], dim=0) + + if ip_hidden_states is not None: + # apply scaler + if self.train_scaler: + weight = self.ip_scaler + # reshape to (1, self.num_tokens, 1) + weight = weight.view(1, -1, 1) + ip_hidden_states = ip_hidden_states * weight + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + scale = self.scale + hidden_states = hidden_states + scale * ip_hidden_states + # end ip adapter + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + +# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py +class IPAdapter(torch.nn.Module): + """IP-Adapter""" + + def __init__(self, sd: 
'StableDiffusion', adapter_config: 'AdapterConfig'): + super().__init__() + self.config = adapter_config + self.sd_ref: weakref.ref = weakref.ref(sd) + self.device = self.sd_ref().unet.device + self.preprocessor: Optional[CLIPImagePreProcessor] = None + self.input_size = 224 + self.clip_noise_zero = True + self.unconditional: torch.Tensor = None + + self.last_conditional: torch.Tensor = None + self.last_unconditional: torch.Tensor = None + + self.additional_loss = None + if self.config.image_encoder_arch.startswith("clip"): + try: + self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = CLIPImageProcessor() + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'siglip': + from transformers import SiglipImageProcessor, SiglipVisionModel + try: + self.clip_image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = SiglipImageProcessor() + self.image_encoder = SiglipVisionModel.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'vit': + try: + self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = ViTFeatureExtractor() + self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'safe': + try: + self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = SAFEImageProcessor() + self.image_encoder = SAFEVisionModel( + in_channels=3, + num_tokens=self.config.safe_tokens, + num_vectors=sd.unet.config['cross_attention_dim'], + reducer_channels=self.config.safe_reducer_channels, + channels=self.config.safe_channels, + downscale_factor=8 + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'convnext': + try: + self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.clip_image_processor = ConvNextImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.image_encoder = ConvNextForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'convnextv2': + try: + self.clip_image_processor = AutoImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.clip_image_processor = ConvNextImageProcessor( + size=512, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + self.image_encoder = ConvNextV2ForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, 
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'vit-hybrid': + try: + self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.clip_image_processor = ViTHybridImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.image_encoder = ViTHybridForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + else: + raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}") + + if not self.config.train_image_encoder: + # compile it + print('Compiling image encoder') + #torch.compile(self.image_encoder, fullgraph=True) + + self.input_size = self.image_encoder.config.image_size + + if self.config.quad_image: # 4x4 image + # self.clip_image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = self.image_encoder.config.image_size * 2 + + # update the preprocessor so images come in at the right size + if 'height' in self.clip_image_processor.size: + self.clip_image_processor.size['height'] = preprocessor_input_size + self.clip_image_processor.size['width'] = preprocessor_input_size + elif hasattr(self.clip_image_processor, 'crop_size'): + self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size + self.clip_image_processor.crop_size['height'] = preprocessor_input_size + self.clip_image_processor.crop_size['width'] = preprocessor_input_size + + if self.config.image_encoder_arch == 'clip+': + # self.clip_image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = self.image_encoder.config.image_size * 4 + + # update the preprocessor so images come in at the right size + self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size + self.clip_image_processor.crop_size['height'] = preprocessor_input_size + self.clip_image_processor.crop_size['width'] = preprocessor_input_size + + self.preprocessor = CLIPImagePreProcessor( + input_size=preprocessor_input_size, + clip_input_size=self.image_encoder.config.image_size, + ) + if not self.config.image_encoder_arch == 'safe': + if 'height' in self.clip_image_processor.size: + self.input_size = self.clip_image_processor.size['height'] + elif hasattr(self.clip_image_processor, 'crop_size'): + self.input_size = self.clip_image_processor.crop_size['height'] + elif 'shortest_edge' in self.clip_image_processor.size.keys(): + self.input_size = self.clip_image_processor.size['shortest_edge'] + else: + raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}") + self.current_scale = 1.0 + self.is_active = True + is_pixart = sd.is_pixart + is_flux = sd.is_flux + if adapter_config.type == 'ip': + # ip-adapter + image_proj_model = ImageProjModel( + cross_attention_dim=sd.unet.config['cross_attention_dim'], + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.config.num_tokens, # usually 4 + ) + elif adapter_config.type == 'ip_clip_face': + cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim'] + image_proj_model = MLPProjModelClipFace( + cross_attention_dim=cross_attn_dim, + 
id_embeddings_dim=self.image_encoder.config.projection_dim, + num_tokens=self.config.num_tokens, # usually 4 + ) + elif adapter_config.type == 'ip+': + heads = 12 if not sd.is_xl else 20 + if is_flux: + dim = 1280 + else: + dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280 + embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith( + 'convnext') else \ + self.image_encoder.config.hidden_sizes[-1] + + image_encoder_state_dict = self.image_encoder.state_dict() + # max_seq_len = CLIP tokens + CLS token + max_seq_len = 257 + if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: + # clip + max_seq_len = int( + image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + if is_pixart: + heads = 20 + dim = 1280 + output_dim = 4096 + elif is_flux: + heads = 20 + dim = 1280 + output_dim = 3072 + else: + output_dim = sd.unet.config['cross_attention_dim'] + + if self.config.image_encoder_arch.startswith('convnext'): + in_tokens = 16 * 16 + embedding_dim = self.image_encoder.config.hidden_sizes[-1] + + # ip-adapter-plus + image_proj_model = Resampler( + dim=dim, + depth=4, + dim_head=64, + heads=heads, + num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len, + embedding_dim=embedding_dim, + max_seq_len=max_seq_len, + output_dim=output_dim, + ff_mult=4 + ) + elif adapter_config.type == 'ipz': + dim = sd.unet.config['cross_attention_dim'] + if hasattr(self.image_encoder.config, 'hidden_sizes'): + embedding_dim = self.image_encoder.config.hidden_sizes[-1] + else: + embedding_dim = self.image_encoder.config.target_hidden_size + + image_encoder_state_dict = self.image_encoder.state_dict() + # max_seq_len = CLIP tokens + CLS token + in_tokens = 257 + if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: + # clip + in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + if self.config.image_encoder_arch.startswith('convnext'): + in_tokens = 16 * 16 + embedding_dim = self.image_encoder.config.hidden_sizes[-1] + + is_conv_next = self.config.image_encoder_arch.startswith('convnext') + + out_tokens = self.config.num_tokens if self.config.num_tokens > 0 else in_tokens + # ip-adapter-plus + image_proj_model = ZipperResampler( + in_size=embedding_dim, + in_tokens=in_tokens, + out_size=dim, + out_tokens=out_tokens, + hidden_size=embedding_dim, + hidden_tokens=in_tokens, + # num_blocks=1 if not is_conv_next else 2, + num_blocks=1 if not is_conv_next else 2, + is_conv_input=is_conv_next + ) + elif adapter_config.type == 'ilora': + # we apply the clip encodings to the LoRA + image_proj_model = None + else: + raise ValueError(f"unknown adapter type: {adapter_config.type}") + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + elif is_flux: + transformer: FluxTransformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn") + + # single transformer blocks do not have cross attn, but we will do them anyway + for i, module in 
transformer.single_transformer_blocks.named_children(): + attn_processor_keys.append(f"single_transformer_blocks.{i}.attn") + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + attn_processor_names = [] + + blocks = [] + transformer_blocks = [] + for name in attn_processor_keys: + name_split = name.split(".") + block_name = f"{name_split[0]}.{name_split[1]}" + transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1 + if transformer_idx >= 0: + transformer_name = ".".join(name_split[:2]) + transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2]) + if transformer_name not in transformer_blocks: + transformer_blocks.append(transformer_name) + + + if block_name not in blocks: + blocks.append(block_name) + if is_flux: + cross_attention_dim = None + else: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \ + sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer") or name.startswith("single_transformer"): + if is_flux: + hidden_size = 3072 + else: + hidden_size = sd.unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None and not is_flux: + attn_procs[name] = AttnProcessor2_0() + else: + layer_name = name.split(".processor")[0] + + # if quantized, we need to scale the weights + if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux: + # is quantized + + k_weight = torch.randn(hidden_size, hidden_size) * 0.01 + v_weight = torch.randn(hidden_size, hidden_size) * 0.01 + k_weight = k_weight.to(self.sd_ref().torch_dtype) + v_weight = v_weight.to(self.sd_ref().torch_dtype) + else: + k_weight = unet_sd[layer_name + ".to_k.weight"] + v_weight = unet_sd[layer_name + ".to_v.weight"] + + weights = { + "to_k_ip.weight": k_weight, + "to_v_ip.weight": v_weight + } + + if is_flux: + attn_procs[name] = CustomIPFluxAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.config.num_tokens, + adapter=self, + train_scaler=self.config.train_scaler or self.config.merge_scaler, + full_token_scaler=False + ) + else: + attn_procs[name] = CustomIPAttentionProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.config.num_tokens, + adapter=self, + train_scaler=self.config.train_scaler or self.config.merge_scaler, + # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler + full_token_scaler=False + ) + if self.sd_ref().is_pixart or self.sd_ref().is_flux: + # pixart is much more sensitive + weights = { + "to_k_ip.weight": weights["to_k_ip.weight"] * 0.01, + "to_v_ip.weight": weights["to_v_ip.weight"] * 0.01, + } + + attn_procs[name].load_state_dict(weights, strict=False) + attn_processor_names.append(name) + print(f"Attn Processors") + print(attn_processor_names) + if self.sd_ref().is_pixart: + # we have to set them ourselves + 
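# (unlike the SD/SDXL branch further below, which uses sd.unet.set_attn_processor(attn_procs), + # the PixArt and Flux branches assign each processor directly on its transformer block) +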
transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn2.processor for i in + range(len(transformer.transformer_blocks)) + ]) + elif self.sd_ref().is_flux: + # we have to set them ourselves + transformer: FluxTransformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"] + + # do single blocks too even though they dont have cross attn + for i, module in transformer.single_transformer_blocks.named_children(): + module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"] + + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn.processor for i in + range(len(transformer.transformer_blocks)) + ] + [ + transformer.single_transformer_blocks[i].attn.processor for i in + range(len(transformer.single_transformer_blocks)) + ] + ) + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + sd.adapter = self + self.unet_ref: weakref.ref = weakref.ref(sd.unet) + self.image_proj_model = image_proj_model + # load the weights if we have some + if self.config.name_or_path: + loaded_state_dict = load_ip_adapter_model( + self.config.name_or_path, + device='cpu', + dtype=sd.torch_dtype + ) + self.load_state_dict(loaded_state_dict) + + self.set_scale(1.0) + + if self.config.train_image_encoder: + self.image_encoder.train() + self.image_encoder.requires_grad_(True) + + # premake a unconditional + zerod = torch.zeros(1, 3, self.input_size, self.input_size, device=self.device, dtype=torch.float16) + self.unconditional = self.clip_image_processor( + images=zerod, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.image_encoder.to(*args, **kwargs) + self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + if self.preprocessor is not None: + self.preprocessor.to(*args, **kwargs) + return self + + # def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]): + # self.image_proj_model.load_state_dict(state_dict["image_proj"]) + # ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + # ip_layers.load_state_dict(state_dict["ip_adapter"]) + # if self.config.train_image_encoder and 'image_encoder' in state_dict: + # self.image_encoder.load_state_dict(state_dict["image_encoder"]) + # if self.preprocessor is not None and 'preprocessor' in state_dict: + # self.preprocessor.load_state_dict(state_dict["preprocessor"]) + + # def load_state_dict(self, state_dict: Union[OrderedDict, dict]): + # self.load_ip_adapter(state_dict) + + def state_dict(self) -> OrderedDict: + state_dict = OrderedDict() + if self.config.train_only_image_encoder: + return self.image_encoder.state_dict() + if self.config.train_scaler: + state_dict["ip_scale"] = self.adapter_modules.state_dict() + # remove items that are not scalers + for key in list(state_dict["ip_scale"].keys()): + if not key.endswith("ip_scaler"): + del state_dict["ip_scale"][key] + return state_dict + + state_dict["image_proj"] = self.image_proj_model.state_dict() + state_dict["ip_adapter"] = self.adapter_modules.state_dict() + # 
handle merge scaler training + if self.config.merge_scaler: + for key in list(state_dict["ip_adapter"].keys()): + if key.endswith("ip_scaler"): + # merge in the scaler so we dont have to save it and it will be compatible with other ip adapters + scale = state_dict["ip_adapter"][key].clone() + + key_start = key.split(".")[-2] + # reshape to (1, 1) + scale = scale.view(1, 1) + del state_dict["ip_adapter"][key] + # find the to_k_ip and to_v_ip keys + for key2 in list(state_dict["ip_adapter"].keys()): + if key2.endswith(f"{key_start}.to_k_ip.weight"): + state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale + if key2.endswith(f"{key_start}.to_v_ip.weight"): + state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale + + if self.config.train_image_encoder: + state_dict["image_encoder"] = self.image_encoder.state_dict() + if self.preprocessor is not None: + state_dict["preprocessor"] = self.preprocessor.state_dict() + return state_dict + + def get_scale(self): + return self.current_scale + + def set_scale(self, scale): + self.current_scale = scale + if not self.sd_ref().is_pixart and not self.sd_ref().is_flux: + for attn_processor in self.sd_ref().unet.attn_processors.values(): + if isinstance(attn_processor, CustomIPAttentionProcessor): + attn_processor.scale = scale + + # @torch.no_grad() + # def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], + # drop=False) -> torch.Tensor: + # # todo: add support for sdxl + # if isinstance(pil_image, Image.Image): + # pil_image = [pil_image] + # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + # clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + # if drop: + # clip_image = clip_image * 0 + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + # return clip_image_embeds + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.image_encoder.to(*args, **kwargs) + self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + if self.preprocessor is not None: + self.preprocessor.to(*args, **kwargs) + return self + + def parse_clip_image_embeds_from_cache( + self, + image_embeds_list: List[dict], # has ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states'] + quad_count=4, + ): + with torch.no_grad(): + device = self.sd_ref().unet.device + clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0) + + if self.config.quad_image: + # get the outputs of the quat + chunks = clip_image_embeds.chunk(quad_count, dim=0) + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + clip_image_embeds = chunk_sum / quad_count + + clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + return clip_image_embeds + + def get_empty_clip_image(self, batch_size: int) -> torch.Tensor: + with torch.no_grad(): + tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device) + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = 
torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + return clip_image.detach() + + def get_clip_image_embeds_from_tensors( + self, + tensors_0_1: torch.Tensor, + drop=False, + is_training=False, + has_been_preprocessed=False, + quad_count=4, + cfg_embed_strength=None, # perform CFG on embeds with unconditional as negative + ) -> torch.Tensor: + if self.sd_ref().unet.device != self.device: + self.to(self.sd_ref().unet.device) + if self.sd_ref().unet.device != self.image_encoder.device: + self.to(self.sd_ref().unet.device) + if not self.config.train: + is_training = False + uncond_clip = None + with torch.no_grad(): + # on training the clip image is created in the dataloader + if not has_been_preprocessed: + # tensors should be 0-1 + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + # unconditional + if drop: + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + clip_image = self.clip_image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + else: + if drop: + # scale the noise down + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. 
* tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + + else: + clip_image = tensors_0_1 + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + + if self.config.quad_image: + # split the 2x2 grid and stack on batch + ci1, ci2 = clip_image.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + clip_image = torch.cat(to_cat, dim=0).detach() + + # if drop: + # clip_image = clip_image * 0 + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.image_encoder.train() + clip_image = clip_image.requires_grad_(True) + if self.preprocessor is not None: + clip_image = self.preprocessor(clip_image) + clip_output = self.image_encoder( + clip_image, + output_hidden_states=True + ) + else: + self.image_encoder.eval() + if self.preprocessor is not None: + clip_image = self.preprocessor(clip_image) + clip_output = self.image_encoder( + clip_image, output_hidden_states=True + ) + + if self.config.clip_layer == 'penultimate_hidden_states': + # they skip last layer for ip+ + # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26 + clip_image_embeds = clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + clip_image_embeds = clip_output.hidden_states[-1] + else: + clip_image_embeds = clip_output.image_embeds + + if self.config.adapter_type == "clip_face": + l2_norm = torch.norm(clip_image_embeds, p=2) + clip_image_embeds = clip_image_embeds / l2_norm + + if self.config.image_encoder_arch.startswith('convnext'): + # flatten the width height layers to make the token space + clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1) + # rearrange to (batch, tokens, size) + clip_image_embeds = clip_image_embeds.permute(0, 2, 1) + + # apply unconditional if doing cfg on embeds + with torch.no_grad(): + if cfg_embed_strength is not None: + uncond_clip = self.get_empty_clip_image(tensors_0_1.shape[0]).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + if self.config.quad_image: + # split the 2x2 grid and stack on batch + ci1, ci2 = uncond_clip.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + uncond_clip = torch.cat(to_cat, dim=0).detach() + uncond_clip_output = self.image_encoder( + uncond_clip, output_hidden_states=True + ) + + if self.config.clip_layer == 'penultimate_hidden_states': + uncond_clip_output_embeds = uncond_clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + uncond_clip_output_embeds = uncond_clip_output.hidden_states[-1] + else: + uncond_clip_output_embeds = uncond_clip_output.image_embeds + if self.config.adapter_type == "clip_face": + l2_norm = torch.norm(uncond_clip_output_embeds, p=2) + uncond_clip_output_embeds = uncond_clip_output_embeds / l2_norm + + uncond_clip_output_embeds = uncond_clip_output_embeds.detach() + + + # apply inverse cfg + clip_image_embeds = inverse_classifier_guidance( + clip_image_embeds, + uncond_clip_output_embeds, + cfg_embed_strength + ) + + + if self.config.quad_image: + # get the outputs of the quads + chunks =
clip_image_embeds.chunk(quad_count, dim=0) + if self.config.train_image_encoder and is_training: + # perform a loss across all chunks; this will teach the vision encoder to + # identify similarities in our pairs of images and ignore things that do not make them similar + num_losses = 0 + total_loss = None + for chunk in chunks: + for chunk2 in chunks: + if chunk is not chunk2: + loss = F.mse_loss(chunk, chunk2) + if total_loss is None: + total_loss = loss + else: + total_loss = total_loss + loss + num_losses += 1 + if total_loss is not None: + total_loss = total_loss / num_losses + total_loss = total_loss * 1e-2 + if self.additional_loss is not None: + total_loss = total_loss + self.additional_loss + self.additional_loss = total_loss + + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + clip_image_embeds = chunk_sum / quad_count + + if not is_training or not self.config.train_image_encoder: + clip_image_embeds = clip_image_embeds.detach() + + return clip_image_embeds + + # use drop for prompt dropout, or negatives + def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor, is_unconditional=False) -> PromptEmbeds: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + if self.sd_ref().is_flux: + # do not attach to text embeds for flux; we will save and grab them as it messes + # with the RoPE to have them in the same tensor + if is_unconditional: + self.last_unconditional = image_prompt_embeds + else: + self.last_conditional = image_prompt_embeds + else: + embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1) + return embeddings + + def train(self: T, mode: bool = True) -> T: + if self.config.train_image_encoder: + self.image_encoder.train(mode) + if not self.config.train_only_image_encoder: + for attn_processor in self.adapter_modules: + attn_processor.train(mode) + if self.image_proj_model is not None: + self.image_proj_model.train(mode) + return super().train(mode) + + def get_parameter_groups(self, adapter_lr): + param_groups = [] + # when training just scaler, we do not train anything else + if not self.config.train_scaler: + param_groups.append({ + "params": list(self.get_non_scaler_parameters()), + "lr": adapter_lr, + }) + if self.config.train_scaler or self.config.merge_scaler: + scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr + param_groups.append({ + "params": list(self.get_scaler_parameters()), + "lr": scaler_lr, + }) + return param_groups + + def get_scaler_parameters(self): + # only get the scalers from the adapter modules + for attn_processor in self.adapter_modules: + # only get the scaler + # check if it has ip_scaler attribute + if hasattr(attn_processor, "ip_scaler"): + scaler_param = attn_processor.ip_scaler + yield scaler_param + + def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]: + if self.config.train_only_image_encoder: + if self.config.train_only_image_encoder_positional_embedding: + yield from self.image_encoder.vision_model.embeddings.position_embedding.parameters(recurse) + else: + yield from self.image_encoder.parameters(recurse) + return + if self.config.train_scaler: + # no params + return + + for attn_processor in self.adapter_modules: + if self.config.train_scaler or self.config.merge_scaler: + # todo remove scaler + if hasattr(attn_processor, "to_k_ip"): + #
yield the linear layer + yield from attn_processor.to_k_ip.parameters(recurse) + if hasattr(attn_processor, "to_v_ip"): + # yield the linear layer + yield from attn_processor.to_v_ip.parameters(recurse) + else: + yield from attn_processor.parameters(recurse) + yield from self.image_proj_model.parameters(recurse) + if self.config.train_image_encoder: + yield from self.image_encoder.parameters(recurse) + if self.preprocessor is not None: + yield from self.preprocessor.parameters(recurse) + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + yield from self.get_non_scaler_parameters(recurse) + if self.config.train_scaler or self.config.merge_scaler: + yield from self.get_scaler_parameters() + + def merge_in_weights(self, state_dict: Mapping[str, Any]): + # merge in img_proj weights + current_img_proj_state_dict = self.image_proj_model.state_dict() + for key, value in state_dict["image_proj"].items(): + if key in current_img_proj_state_dict: + current_shape = current_img_proj_state_dict[key].shape + new_shape = value.shape + if current_shape != new_shape: + try: + # merge in what we can and leave the other values as they are + if len(current_shape) == 1: + current_img_proj_state_dict[key][:new_shape[0]] = value + elif len(current_shape) == 2: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value + elif len(current_shape) == 3: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value + elif len(current_shape) == 4: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], + :new_shape[3]] = value + else: + raise ValueError(f"unknown shape: {current_shape}") + except RuntimeError as e: + print(e) + print( + f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. 
Trying other way") + + if len(current_shape) == 1: + current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]] + elif len(current_shape) == 2: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[ + :current_shape[0], + :current_shape[1]] + elif len(current_shape) == 3: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], + :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] + elif len(current_shape) == 4: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + else: + current_img_proj_state_dict[key] = value + self.image_proj_model.load_state_dict(current_img_proj_state_dict) + + # merge in ip adapter weights + current_ip_adapter_state_dict = self.adapter_modules.state_dict() + for key, value in state_dict["ip_adapter"].items(): + if key in current_ip_adapter_state_dict: + current_shape = current_ip_adapter_state_dict[key].shape + new_shape = value.shape + if current_shape != new_shape: + try: + # merge in what we can and leave the other values as they are + if len(current_shape) == 1: + current_ip_adapter_state_dict[key][:new_shape[0]] = value + elif len(current_shape) == 2: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value + elif len(current_shape) == 3: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value + elif len(current_shape) == 4: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], + :new_shape[3]] = value + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + except RuntimeError as e: + print(e) + print( + f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. 
Trying other way") + + if (len(current_shape) == 1): + current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]] + elif (len(current_shape) == 2): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[ + :current_shape[ + 0], + :current_shape[ + 1]] + elif (len(current_shape) == 3): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], + :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] + elif (len(current_shape) == 4): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + + else: + current_ip_adapter_state_dict[key] = value + self.adapter_modules.load_state_dict(current_ip_adapter_state_dict) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + strict = False + if self.config.train_scaler and 'ip_scale' in state_dict: + self.adapter_modules.load_state_dict(state_dict["ip_scale"], strict=False) + if 'ip_adapter' in state_dict: + try: + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) + self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict) + except Exception as e: + print(e) + print("could not load ip adapter weights, trying to merge in weights") + self.merge_in_weights(state_dict) + if self.config.train_image_encoder and 'image_encoder' in state_dict: + self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict) + if self.preprocessor is not None and 'preprocessor' in state_dict: + self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict) + + if self.config.train_only_image_encoder and 'ip_adapter' not in state_dict: + # we are loading pure clip weights. + self.image_encoder.load_state_dict(state_dict, strict=strict) + + def enable_gradient_checkpointing(self): + if hasattr(self.image_encoder, "enable_gradient_checkpointing"): + self.image_encoder.enable_gradient_checkpointing() + elif hasattr(self.image_encoder, 'gradient_checkpointing'): + self.image_encoder.gradient_checkpointing = True diff --git a/toolkit/job.py b/toolkit/job.py new file mode 100644 index 0000000000000000000000000000000000000000..dc274fb798efb5780056bd2134eb2940a608a98c --- /dev/null +++ b/toolkit/job.py @@ -0,0 +1,44 @@ +from typing import Union, OrderedDict + +from toolkit.config import get_config + + +def get_job( + config_path: Union[str, dict, OrderedDict], + name=None +): + config = get_config(config_path, name) + if not config['job']: + raise ValueError('config file is invalid. 
Missing "job" key') + + job = config['job'] + if job == 'extract': + from jobs import ExtractJob + return ExtractJob(config) + if job == 'train': + from jobs import TrainJob + return TrainJob(config) + if job == 'mod': + from jobs import ModJob + return ModJob(config) + if job == 'generate': + from jobs import GenerateJob + return GenerateJob(config) + if job == 'extension': + from jobs import ExtensionJob + return ExtensionJob(config) + + # elif job == 'train': + # from jobs import TrainJob + # return TrainJob(config) + else: + raise ValueError(f'Unknown job type {job}') + + +def run_job( + config: Union[str, dict, OrderedDict], + name=None +): + job = get_job(config, name) + job.run() + job.cleanup() diff --git a/toolkit/keymaps/stable_diffusion_refiner.json b/toolkit/keymaps/stable_diffusion_refiner.json new file mode 100644 index 0000000000000000000000000000000000000000..4c7525d8804da9ec92b7f87bc01741d4372ac83d --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_refiner.json @@ -0,0 +1,3498 @@ +{ + "ldm_diffusers_keymap": { + "conditioner.embedders.0.model.ln_final.bias": "te1_text_model.final_layer_norm.bias", + "conditioner.embedders.0.model.ln_final.weight": "te1_text_model.final_layer_norm.weight", + "conditioner.embedders.0.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.0.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight", + 
"conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.ln_1.weight": 
"te1_text_model.encoder.layers.12.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias", + 
"conditioner.embedders.0.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_proj.bias": 
"te1_text_model.encoder.layers.17.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight", + 
"conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.ln_1.weight": 
"te1_text_model.encoder.layers.22.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias", + 
"conditioner.embedders.0.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_proj.bias": 
"te1_text_model.encoder.layers.27.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight", + 
"conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.ln_1.weight": 
"te1_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.attn.out_proj.weight": 
"te1_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": 
"vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": 
"vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + 
"first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": 
"vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + 
"first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + 
"first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": 
"vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight", + 
"model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": 
"unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_k.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": 
"unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": 
"unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_k.weight": 
"unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": 
"unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + 
"model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight", + "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias", + "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight", + "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias", + "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": 
"unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_k.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias", + 
"model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias", + 
"model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": 
"unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": 
"unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": 
"unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm1.bias": 
"unet_up_blocks.1.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm2.weight": 
"unet_up_blocks.1.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": 
"unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.2.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm2.bias": 
"unet_up_blocks.1.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.bias", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm1.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.7.1.proj_in.weight": "unet_up_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.7.1.proj_out.bias": "unet_up_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": 
"unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": 
"unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": 
"unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.2.weight": 
"unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.8.1.proj_in.bias": "unet_up_blocks.2.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.8.1.proj_out.weight": "unet_up_blocks.2.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_v.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.2.bias", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "unet_up_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 
512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.25.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1280, :" + ] + }, 
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": { 
+ "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight", + 
"1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": 
{ + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.31.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight", + "2560:, :" + ] + } + } +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..02e3ebb921777760664f8073d2131f2503b60d81 --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9d9f0fd82268e59d252443653a6e3c30ef4d30fb0b418fadadcebc2572608aa +size 3277018 diff --git a/toolkit/keymaps/stable_diffusion_refiner_unmatched.json b/toolkit/keymaps/stable_diffusion_refiner_unmatched.json new file mode 100644 index 0000000000000000000000000000000000000000..cb5aba0a543c8ad50094abb3f99e266336908aaa --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_refiner_unmatched.json @@ -0,0 +1,27 @@ +{ + "ldm": { + "conditioner.embedders.0.model.logit_scale": { + "shape": [], + "min": 4.60546875, + "max": 4.60546875 + }, + "conditioner.embedders.0.model.text_projection": { + "shape": [ + 1280, + 1280 + ], + "min": -0.15966796875, + "max": 0.230712890625 + } + }, + "diffusers": { + "te1_text_projection.weight": { + "shape": [ + 1280, + 1280 + ], + "min": -0.15966796875, + "max": 0.230712890625 + } + } +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_sd1.json b/toolkit/keymaps/stable_diffusion_sd1.json new file mode 100644 index 0000000000000000000000000000000000000000..8f04f753ac6656fdc2a2d44d8d07ebc7db184689 --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_sd1.json @@ -0,0 +1,1234 @@ +{ + "ldm_diffusers_keymap": { + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "te_text_model.embeddings.position_embedding.weight", + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "te_text_model.embeddings.token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te_text_model.encoder.layers.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te_text_model.encoder.layers.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te_text_model.encoder.layers.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te_text_model.encoder.layers.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te_text_model.encoder.layers.0.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te_text_model.encoder.layers.0.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te_text_model.encoder.layers.0.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te_text_model.encoder.layers.0.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te_text_model.encoder.layers.0.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te_text_model.encoder.layers.0.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te_text_model.encoder.layers.0.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te_text_model.encoder.layers.0.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te_text_model.encoder.layers.0.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te_text_model.encoder.layers.0.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te_text_model.encoder.layers.0.self_attn.v_proj.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te_text_model.encoder.layers.0.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te_text_model.encoder.layers.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te_text_model.encoder.layers.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te_text_model.encoder.layers.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te_text_model.encoder.layers.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te_text_model.encoder.layers.1.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te_text_model.encoder.layers.1.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te_text_model.encoder.layers.1.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te_text_model.encoder.layers.1.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te_text_model.encoder.layers.1.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te_text_model.encoder.layers.1.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te_text_model.encoder.layers.1.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te_text_model.encoder.layers.1.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te_text_model.encoder.layers.1.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te_text_model.encoder.layers.1.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te_text_model.encoder.layers.1.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te_text_model.encoder.layers.1.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te_text_model.encoder.layers.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te_text_model.encoder.layers.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te_text_model.encoder.layers.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te_text_model.encoder.layers.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te_text_model.encoder.layers.10.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te_text_model.encoder.layers.10.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te_text_model.encoder.layers.10.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te_text_model.encoder.layers.10.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te_text_model.encoder.layers.10.self_attn.k_proj.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te_text_model.encoder.layers.10.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te_text_model.encoder.layers.10.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te_text_model.encoder.layers.10.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te_text_model.encoder.layers.10.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te_text_model.encoder.layers.10.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te_text_model.encoder.layers.10.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te_text_model.encoder.layers.10.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te_text_model.encoder.layers.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te_text_model.encoder.layers.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te_text_model.encoder.layers.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te_text_model.encoder.layers.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te_text_model.encoder.layers.11.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te_text_model.encoder.layers.11.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te_text_model.encoder.layers.11.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te_text_model.encoder.layers.11.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te_text_model.encoder.layers.11.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te_text_model.encoder.layers.11.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te_text_model.encoder.layers.11.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te_text_model.encoder.layers.11.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te_text_model.encoder.layers.11.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te_text_model.encoder.layers.11.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te_text_model.encoder.layers.11.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te_text_model.encoder.layers.11.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te_text_model.encoder.layers.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te_text_model.encoder.layers.2.layer_norm1.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te_text_model.encoder.layers.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te_text_model.encoder.layers.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te_text_model.encoder.layers.2.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te_text_model.encoder.layers.2.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te_text_model.encoder.layers.2.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te_text_model.encoder.layers.2.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te_text_model.encoder.layers.2.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te_text_model.encoder.layers.2.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te_text_model.encoder.layers.2.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te_text_model.encoder.layers.2.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te_text_model.encoder.layers.2.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te_text_model.encoder.layers.2.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te_text_model.encoder.layers.2.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te_text_model.encoder.layers.2.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te_text_model.encoder.layers.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te_text_model.encoder.layers.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te_text_model.encoder.layers.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te_text_model.encoder.layers.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te_text_model.encoder.layers.3.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te_text_model.encoder.layers.3.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te_text_model.encoder.layers.3.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te_text_model.encoder.layers.3.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te_text_model.encoder.layers.3.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te_text_model.encoder.layers.3.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te_text_model.encoder.layers.3.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te_text_model.encoder.layers.3.self_attn.out_proj.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te_text_model.encoder.layers.3.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te_text_model.encoder.layers.3.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te_text_model.encoder.layers.3.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te_text_model.encoder.layers.3.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te_text_model.encoder.layers.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te_text_model.encoder.layers.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te_text_model.encoder.layers.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te_text_model.encoder.layers.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te_text_model.encoder.layers.4.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te_text_model.encoder.layers.4.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te_text_model.encoder.layers.4.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te_text_model.encoder.layers.4.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te_text_model.encoder.layers.4.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te_text_model.encoder.layers.4.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te_text_model.encoder.layers.4.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te_text_model.encoder.layers.4.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te_text_model.encoder.layers.4.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te_text_model.encoder.layers.4.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te_text_model.encoder.layers.4.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te_text_model.encoder.layers.4.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te_text_model.encoder.layers.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te_text_model.encoder.layers.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te_text_model.encoder.layers.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te_text_model.encoder.layers.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te_text_model.encoder.layers.5.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te_text_model.encoder.layers.5.mlp.fc1.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te_text_model.encoder.layers.5.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te_text_model.encoder.layers.5.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te_text_model.encoder.layers.5.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te_text_model.encoder.layers.5.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te_text_model.encoder.layers.5.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te_text_model.encoder.layers.5.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te_text_model.encoder.layers.5.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te_text_model.encoder.layers.5.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te_text_model.encoder.layers.5.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te_text_model.encoder.layers.5.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te_text_model.encoder.layers.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te_text_model.encoder.layers.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te_text_model.encoder.layers.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te_text_model.encoder.layers.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te_text_model.encoder.layers.6.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te_text_model.encoder.layers.6.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te_text_model.encoder.layers.6.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te_text_model.encoder.layers.6.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te_text_model.encoder.layers.6.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te_text_model.encoder.layers.6.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te_text_model.encoder.layers.6.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te_text_model.encoder.layers.6.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te_text_model.encoder.layers.6.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te_text_model.encoder.layers.6.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te_text_model.encoder.layers.6.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": 
"te_text_model.encoder.layers.6.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te_text_model.encoder.layers.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te_text_model.encoder.layers.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te_text_model.encoder.layers.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te_text_model.encoder.layers.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te_text_model.encoder.layers.7.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te_text_model.encoder.layers.7.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te_text_model.encoder.layers.7.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te_text_model.encoder.layers.7.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te_text_model.encoder.layers.7.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te_text_model.encoder.layers.7.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te_text_model.encoder.layers.7.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te_text_model.encoder.layers.7.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te_text_model.encoder.layers.7.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te_text_model.encoder.layers.7.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te_text_model.encoder.layers.7.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te_text_model.encoder.layers.7.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te_text_model.encoder.layers.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te_text_model.encoder.layers.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te_text_model.encoder.layers.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te_text_model.encoder.layers.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te_text_model.encoder.layers.8.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te_text_model.encoder.layers.8.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te_text_model.encoder.layers.8.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te_text_model.encoder.layers.8.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te_text_model.encoder.layers.8.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te_text_model.encoder.layers.8.self_attn.k_proj.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te_text_model.encoder.layers.8.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te_text_model.encoder.layers.8.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te_text_model.encoder.layers.8.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te_text_model.encoder.layers.8.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te_text_model.encoder.layers.8.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te_text_model.encoder.layers.8.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te_text_model.encoder.layers.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te_text_model.encoder.layers.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te_text_model.encoder.layers.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te_text_model.encoder.layers.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te_text_model.encoder.layers.9.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te_text_model.encoder.layers.9.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te_text_model.encoder.layers.9.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te_text_model.encoder.layers.9.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te_text_model.encoder.layers.9.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te_text_model.encoder.layers.9.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te_text_model.encoder.layers.9.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te_text_model.encoder.layers.9.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te_text_model.encoder.layers.9.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te_text_model.encoder.layers.9.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te_text_model.encoder.layers.9.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te_text_model.encoder.layers.9.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "te_text_model.final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "te_text_model.final_layer_norm.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + 
"first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": 
"vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + 
"first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": 
"vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": 
"vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": 
"vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": 
"vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + 
"first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.1.1.norm.bias": "unet_down_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.1.1.norm.weight": "unet_down_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.1.1.proj_in.bias": "unet_down_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.1.1.proj_in.weight": "unet_down_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.1.1.proj_out.bias": "unet_down_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.1.1.proj_out.weight": "unet_down_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + 
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias", + 
"model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.2.1.norm.bias": "unet_down_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.2.1.norm.weight": "unet_down_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.2.1.proj_in.bias": "unet_down_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.2.1.proj_in.weight": "unet_down_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.2.1.proj_out.bias": "unet_down_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.2.1.proj_out.weight": "unet_down_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + 
"model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": 
"unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": 
"unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + 
"model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + 
"model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight", + 
"model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight", + 
"model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.1.norm.bias": "unet_up_blocks.3.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.10.1.norm.weight": "unet_up_blocks.3.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.10.1.proj_in.bias": "unet_up_blocks.3.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.10.1.proj_in.weight": "unet_up_blocks.3.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.10.1.proj_out.bias": "unet_up_blocks.3.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.10.1.proj_out.weight": "unet_up_blocks.3.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.11.1.norm.bias": "unet_up_blocks.3.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.11.1.norm.weight": "unet_up_blocks.3.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.11.1.proj_in.bias": "unet_up_blocks.3.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.11.1.proj_in.weight": "unet_up_blocks.3.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.11.1.proj_out.bias": "unet_up_blocks.3.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.11.1.proj_out.weight": "unet_up_blocks.3.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": 
"unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + 
"model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + 
"model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + 
"model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": 
"unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": 
"unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.7.1.proj_in.weight": "unet_up_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.7.1.proj_out.bias": "unet_up_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.8.1.proj_in.bias": "unet_up_blocks.2.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.8.1.proj_out.weight": "unet_up_blocks.2.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": 
"unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "unet_up_blocks.3.resnets.0.conv2.bias", + 
"model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.9.1.norm.bias": "unet_up_blocks.3.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.9.1.norm.weight": "unet_up_blocks.3.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.9.1.proj_in.bias": "unet_up_blocks.3.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.9.1.proj_in.weight": "unet_up_blocks.3.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.9.1.proj_out.bias": "unet_up_blocks.3.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.9.1.proj_out.weight": "unet_up_blocks.3.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": 
"unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": {}, + "diffusers_ldm_operator_map": {} +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..8e2c4cb90b8d10d6c9a844a3b73ef3e07541f130 --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576 +size 16 diff --git a/toolkit/keymaps/stable_diffusion_sd2.json b/toolkit/keymaps/stable_diffusion_sd2.json new file mode 100644 index 0000000000000000000000000000000000000000..868facaf5b6119f5d3a82d369fe509b82da1f551 --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_sd2.json @@ -0,0 +1,2424 @@ +{ + "ldm_diffusers_keymap": { + "cond_stage_model.model.ln_final.bias": "te_text_model.final_layer_norm.bias", + "cond_stage_model.model.ln_final.weight": "te_text_model.final_layer_norm.weight", + "cond_stage_model.model.positional_embedding": "te_text_model.embeddings.position_embedding.weight", + "cond_stage_model.model.token_embedding.weight": "te_text_model.embeddings.token_embedding.weight", + "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.bias": "te_text_model.encoder.layers.0.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight": "te_text_model.encoder.layers.0.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.0.ln_1.bias": "te_text_model.encoder.layers.0.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.0.ln_1.weight": "te_text_model.encoder.layers.0.layer_norm1.weight", + 
"cond_stage_model.model.transformer.resblocks.0.ln_2.bias": "te_text_model.encoder.layers.0.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.0.ln_2.weight": "te_text_model.encoder.layers.0.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.bias": "te_text_model.encoder.layers.0.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight": "te_text_model.encoder.layers.0.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.bias": "te_text_model.encoder.layers.0.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.weight": "te_text_model.encoder.layers.0.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.1.attn.out_proj.bias": "te_text_model.encoder.layers.1.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.1.attn.out_proj.weight": "te_text_model.encoder.layers.1.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.1.ln_1.bias": "te_text_model.encoder.layers.1.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.1.ln_1.weight": "te_text_model.encoder.layers.1.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.1.ln_2.bias": "te_text_model.encoder.layers.1.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.1.ln_2.weight": "te_text_model.encoder.layers.1.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.bias": "te_text_model.encoder.layers.1.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.weight": "te_text_model.encoder.layers.1.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.bias": "te_text_model.encoder.layers.1.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.weight": "te_text_model.encoder.layers.1.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.10.attn.out_proj.bias": "te_text_model.encoder.layers.10.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.10.attn.out_proj.weight": "te_text_model.encoder.layers.10.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.10.ln_1.bias": "te_text_model.encoder.layers.10.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.10.ln_1.weight": "te_text_model.encoder.layers.10.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.10.ln_2.bias": "te_text_model.encoder.layers.10.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.10.ln_2.weight": "te_text_model.encoder.layers.10.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.bias": "te_text_model.encoder.layers.10.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.weight": "te_text_model.encoder.layers.10.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.bias": "te_text_model.encoder.layers.10.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.weight": "te_text_model.encoder.layers.10.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.11.attn.out_proj.bias": "te_text_model.encoder.layers.11.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.11.attn.out_proj.weight": "te_text_model.encoder.layers.11.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.11.ln_1.bias": "te_text_model.encoder.layers.11.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.11.ln_1.weight": 
"te_text_model.encoder.layers.11.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.11.ln_2.bias": "te_text_model.encoder.layers.11.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.11.ln_2.weight": "te_text_model.encoder.layers.11.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.bias": "te_text_model.encoder.layers.11.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.weight": "te_text_model.encoder.layers.11.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.bias": "te_text_model.encoder.layers.11.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.weight": "te_text_model.encoder.layers.11.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.12.attn.out_proj.bias": "te_text_model.encoder.layers.12.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.12.attn.out_proj.weight": "te_text_model.encoder.layers.12.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.12.ln_1.bias": "te_text_model.encoder.layers.12.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.12.ln_1.weight": "te_text_model.encoder.layers.12.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.12.ln_2.bias": "te_text_model.encoder.layers.12.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.12.ln_2.weight": "te_text_model.encoder.layers.12.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.bias": "te_text_model.encoder.layers.12.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.weight": "te_text_model.encoder.layers.12.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.bias": "te_text_model.encoder.layers.12.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.weight": "te_text_model.encoder.layers.12.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.13.attn.out_proj.bias": "te_text_model.encoder.layers.13.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.13.attn.out_proj.weight": "te_text_model.encoder.layers.13.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.13.ln_1.bias": "te_text_model.encoder.layers.13.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.13.ln_1.weight": "te_text_model.encoder.layers.13.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.13.ln_2.bias": "te_text_model.encoder.layers.13.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.13.ln_2.weight": "te_text_model.encoder.layers.13.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.bias": "te_text_model.encoder.layers.13.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.weight": "te_text_model.encoder.layers.13.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.bias": "te_text_model.encoder.layers.13.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.weight": "te_text_model.encoder.layers.13.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.14.attn.out_proj.bias": "te_text_model.encoder.layers.14.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.14.attn.out_proj.weight": "te_text_model.encoder.layers.14.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.14.ln_1.bias": "te_text_model.encoder.layers.14.layer_norm1.bias", + 
"cond_stage_model.model.transformer.resblocks.14.ln_1.weight": "te_text_model.encoder.layers.14.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.14.ln_2.bias": "te_text_model.encoder.layers.14.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.14.ln_2.weight": "te_text_model.encoder.layers.14.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.bias": "te_text_model.encoder.layers.14.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.weight": "te_text_model.encoder.layers.14.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.bias": "te_text_model.encoder.layers.14.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.weight": "te_text_model.encoder.layers.14.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.15.attn.out_proj.bias": "te_text_model.encoder.layers.15.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.15.attn.out_proj.weight": "te_text_model.encoder.layers.15.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.15.ln_1.bias": "te_text_model.encoder.layers.15.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.15.ln_1.weight": "te_text_model.encoder.layers.15.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.15.ln_2.bias": "te_text_model.encoder.layers.15.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.15.ln_2.weight": "te_text_model.encoder.layers.15.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.bias": "te_text_model.encoder.layers.15.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.weight": "te_text_model.encoder.layers.15.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.bias": "te_text_model.encoder.layers.15.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.weight": "te_text_model.encoder.layers.15.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.16.attn.out_proj.bias": "te_text_model.encoder.layers.16.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.16.attn.out_proj.weight": "te_text_model.encoder.layers.16.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.16.ln_1.bias": "te_text_model.encoder.layers.16.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.16.ln_1.weight": "te_text_model.encoder.layers.16.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.16.ln_2.bias": "te_text_model.encoder.layers.16.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.16.ln_2.weight": "te_text_model.encoder.layers.16.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.bias": "te_text_model.encoder.layers.16.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.weight": "te_text_model.encoder.layers.16.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.bias": "te_text_model.encoder.layers.16.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.weight": "te_text_model.encoder.layers.16.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.17.attn.out_proj.bias": "te_text_model.encoder.layers.17.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.17.attn.out_proj.weight": "te_text_model.encoder.layers.17.self_attn.out_proj.weight", + 
"cond_stage_model.model.transformer.resblocks.17.ln_1.bias": "te_text_model.encoder.layers.17.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.17.ln_1.weight": "te_text_model.encoder.layers.17.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.17.ln_2.bias": "te_text_model.encoder.layers.17.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.17.ln_2.weight": "te_text_model.encoder.layers.17.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.bias": "te_text_model.encoder.layers.17.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.weight": "te_text_model.encoder.layers.17.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.bias": "te_text_model.encoder.layers.17.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.weight": "te_text_model.encoder.layers.17.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.18.attn.out_proj.bias": "te_text_model.encoder.layers.18.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.18.attn.out_proj.weight": "te_text_model.encoder.layers.18.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.18.ln_1.bias": "te_text_model.encoder.layers.18.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.18.ln_1.weight": "te_text_model.encoder.layers.18.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.18.ln_2.bias": "te_text_model.encoder.layers.18.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.18.ln_2.weight": "te_text_model.encoder.layers.18.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.bias": "te_text_model.encoder.layers.18.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.weight": "te_text_model.encoder.layers.18.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.bias": "te_text_model.encoder.layers.18.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.weight": "te_text_model.encoder.layers.18.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.19.attn.out_proj.bias": "te_text_model.encoder.layers.19.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.19.attn.out_proj.weight": "te_text_model.encoder.layers.19.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.19.ln_1.bias": "te_text_model.encoder.layers.19.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.19.ln_1.weight": "te_text_model.encoder.layers.19.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.19.ln_2.bias": "te_text_model.encoder.layers.19.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.19.ln_2.weight": "te_text_model.encoder.layers.19.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.bias": "te_text_model.encoder.layers.19.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.weight": "te_text_model.encoder.layers.19.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.bias": "te_text_model.encoder.layers.19.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.weight": "te_text_model.encoder.layers.19.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.2.attn.out_proj.bias": "te_text_model.encoder.layers.2.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.2.attn.out_proj.weight": 
"te_text_model.encoder.layers.2.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.2.ln_1.bias": "te_text_model.encoder.layers.2.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.2.ln_1.weight": "te_text_model.encoder.layers.2.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.2.ln_2.bias": "te_text_model.encoder.layers.2.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.2.ln_2.weight": "te_text_model.encoder.layers.2.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.bias": "te_text_model.encoder.layers.2.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.weight": "te_text_model.encoder.layers.2.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.bias": "te_text_model.encoder.layers.2.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.weight": "te_text_model.encoder.layers.2.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.20.attn.out_proj.bias": "te_text_model.encoder.layers.20.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.20.attn.out_proj.weight": "te_text_model.encoder.layers.20.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.20.ln_1.bias": "te_text_model.encoder.layers.20.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.20.ln_1.weight": "te_text_model.encoder.layers.20.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.20.ln_2.bias": "te_text_model.encoder.layers.20.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.20.ln_2.weight": "te_text_model.encoder.layers.20.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.bias": "te_text_model.encoder.layers.20.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.weight": "te_text_model.encoder.layers.20.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.bias": "te_text_model.encoder.layers.20.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.weight": "te_text_model.encoder.layers.20.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.21.attn.out_proj.bias": "te_text_model.encoder.layers.21.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.21.attn.out_proj.weight": "te_text_model.encoder.layers.21.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.21.ln_1.bias": "te_text_model.encoder.layers.21.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.21.ln_1.weight": "te_text_model.encoder.layers.21.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.21.ln_2.bias": "te_text_model.encoder.layers.21.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.21.ln_2.weight": "te_text_model.encoder.layers.21.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.bias": "te_text_model.encoder.layers.21.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.weight": "te_text_model.encoder.layers.21.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.bias": "te_text_model.encoder.layers.21.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.weight": "te_text_model.encoder.layers.21.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.bias": "te_text_model.encoder.layers.22.self_attn.out_proj.bias", + 
"cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight": "te_text_model.encoder.layers.22.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.22.ln_1.bias": "te_text_model.encoder.layers.22.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.22.ln_1.weight": "te_text_model.encoder.layers.22.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.22.ln_2.bias": "te_text_model.encoder.layers.22.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.22.ln_2.weight": "te_text_model.encoder.layers.22.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.bias": "te_text_model.encoder.layers.22.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.weight": "te_text_model.encoder.layers.22.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.bias": "te_text_model.encoder.layers.22.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.weight": "te_text_model.encoder.layers.22.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.3.attn.out_proj.bias": "te_text_model.encoder.layers.3.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.3.attn.out_proj.weight": "te_text_model.encoder.layers.3.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.3.ln_1.bias": "te_text_model.encoder.layers.3.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.3.ln_1.weight": "te_text_model.encoder.layers.3.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.3.ln_2.bias": "te_text_model.encoder.layers.3.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.3.ln_2.weight": "te_text_model.encoder.layers.3.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.bias": "te_text_model.encoder.layers.3.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.weight": "te_text_model.encoder.layers.3.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.bias": "te_text_model.encoder.layers.3.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.weight": "te_text_model.encoder.layers.3.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.4.attn.out_proj.bias": "te_text_model.encoder.layers.4.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.4.attn.out_proj.weight": "te_text_model.encoder.layers.4.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.4.ln_1.bias": "te_text_model.encoder.layers.4.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.4.ln_1.weight": "te_text_model.encoder.layers.4.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.4.ln_2.bias": "te_text_model.encoder.layers.4.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.4.ln_2.weight": "te_text_model.encoder.layers.4.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.bias": "te_text_model.encoder.layers.4.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.weight": "te_text_model.encoder.layers.4.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.bias": "te_text_model.encoder.layers.4.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.weight": "te_text_model.encoder.layers.4.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.5.attn.out_proj.bias": 
"te_text_model.encoder.layers.5.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.5.attn.out_proj.weight": "te_text_model.encoder.layers.5.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.5.ln_1.bias": "te_text_model.encoder.layers.5.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.5.ln_1.weight": "te_text_model.encoder.layers.5.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.5.ln_2.bias": "te_text_model.encoder.layers.5.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.5.ln_2.weight": "te_text_model.encoder.layers.5.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.bias": "te_text_model.encoder.layers.5.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.weight": "te_text_model.encoder.layers.5.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.bias": "te_text_model.encoder.layers.5.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.weight": "te_text_model.encoder.layers.5.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.6.attn.out_proj.bias": "te_text_model.encoder.layers.6.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.6.attn.out_proj.weight": "te_text_model.encoder.layers.6.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.6.ln_1.bias": "te_text_model.encoder.layers.6.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.6.ln_1.weight": "te_text_model.encoder.layers.6.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.6.ln_2.bias": "te_text_model.encoder.layers.6.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.6.ln_2.weight": "te_text_model.encoder.layers.6.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.bias": "te_text_model.encoder.layers.6.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.weight": "te_text_model.encoder.layers.6.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.bias": "te_text_model.encoder.layers.6.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.weight": "te_text_model.encoder.layers.6.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.7.attn.out_proj.bias": "te_text_model.encoder.layers.7.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.7.attn.out_proj.weight": "te_text_model.encoder.layers.7.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.7.ln_1.bias": "te_text_model.encoder.layers.7.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.7.ln_1.weight": "te_text_model.encoder.layers.7.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.7.ln_2.bias": "te_text_model.encoder.layers.7.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.7.ln_2.weight": "te_text_model.encoder.layers.7.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.bias": "te_text_model.encoder.layers.7.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.weight": "te_text_model.encoder.layers.7.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.bias": "te_text_model.encoder.layers.7.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.weight": "te_text_model.encoder.layers.7.mlp.fc2.weight", + 
"cond_stage_model.model.transformer.resblocks.8.attn.out_proj.bias": "te_text_model.encoder.layers.8.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.8.attn.out_proj.weight": "te_text_model.encoder.layers.8.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.8.ln_1.bias": "te_text_model.encoder.layers.8.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.8.ln_1.weight": "te_text_model.encoder.layers.8.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.8.ln_2.bias": "te_text_model.encoder.layers.8.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.8.ln_2.weight": "te_text_model.encoder.layers.8.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.bias": "te_text_model.encoder.layers.8.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.weight": "te_text_model.encoder.layers.8.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.bias": "te_text_model.encoder.layers.8.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.weight": "te_text_model.encoder.layers.8.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.9.attn.out_proj.bias": "te_text_model.encoder.layers.9.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.9.attn.out_proj.weight": "te_text_model.encoder.layers.9.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.9.ln_1.bias": "te_text_model.encoder.layers.9.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.9.ln_1.weight": "te_text_model.encoder.layers.9.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.9.ln_2.bias": "te_text_model.encoder.layers.9.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.9.ln_2.weight": "te_text_model.encoder.layers.9.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.bias": "te_text_model.encoder.layers.9.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.weight": "te_text_model.encoder.layers.9.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.bias": "te_text_model.encoder.layers.9.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.weight": "te_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": 
"vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + 
"first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": 
"vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + 
"first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": 
"vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": 
"vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": 
"vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + 
"model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.1.1.norm.bias": "unet_down_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.1.1.norm.weight": "unet_down_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.1.1.proj_in.bias": "unet_down_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.1.1.proj_in.weight": "unet_down_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.1.1.proj_out.bias": "unet_down_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.1.1.proj_out.weight": "unet_down_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + 
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight", + 
"model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.2.1.norm.bias": "unet_down_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.2.1.norm.weight": "unet_down_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.2.1.proj_in.bias": "unet_down_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.2.1.proj_in.weight": "unet_down_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.2.1.proj_out.bias": "unet_down_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.2.1.proj_out.weight": "unet_down_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": 
"unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + 
"model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + 
"model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": 
"unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": 
"unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + 
"model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.1.norm.bias": "unet_up_blocks.3.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.10.1.norm.weight": "unet_up_blocks.3.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.10.1.proj_in.bias": "unet_up_blocks.3.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.10.1.proj_in.weight": 
"unet_up_blocks.3.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.10.1.proj_out.bias": "unet_up_blocks.3.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.10.1.proj_out.weight": "unet_up_blocks.3.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": 
"unet_up_blocks.3.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.11.1.norm.bias": "unet_up_blocks.3.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.11.1.norm.weight": "unet_up_blocks.3.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.11.1.proj_in.bias": "unet_up_blocks.3.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.11.1.proj_in.weight": "unet_up_blocks.3.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.11.1.proj_out.bias": "unet_up_blocks.3.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.11.1.proj_out.weight": "unet_up_blocks.3.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": 
"unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": 
"unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": 
"unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": 
"unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": 
"unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.7.1.proj_in.weight": 
"unet_up_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.7.1.proj_out.bias": "unet_up_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": 
"unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.8.1.proj_in.bias": "unet_up_blocks.2.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.8.1.proj_out.weight": "unet_up_blocks.2.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "unet_up_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.9.1.norm.bias": "unet_up_blocks.3.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.9.1.norm.weight": "unet_up_blocks.3.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.9.1.proj_in.bias": "unet_up_blocks.3.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.9.1.proj_in.weight": "unet_up_blocks.3.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.9.1.proj_out.bias": "unet_up_blocks.3.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.9.1.proj_out.weight": "unet_up_blocks.3.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": 
"unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + 
], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + 
"te_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.13.self_attn.q_proj.weight", + "te_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.17.self_attn.q_proj.bias", + "te_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + 
"te_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + 
"te_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.8.self_attn.q_proj.weight", + "te_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.v_proj.weight": { + 
"slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias", + "1024:2048, :" + ] + }, + 
"te_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1024, :" + ] + }, + 
"te_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight", + 
"0:1024, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight", + "2048:, :" + ] + } + } +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..14d1315934ea605ae2ffdfa143dac5ba4d31788f --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25cdb3685616f5851c554f47beb9ff9c09d0aa3d73e4263b2a94384903dea592 +size 27316630 diff --git a/toolkit/keymaps/stable_diffusion_sd2_unmatched.json b/toolkit/keymaps/stable_diffusion_sd2_unmatched.json new file mode 100644 index 0000000000000000000000000000000000000000..3814d87e7d37f7a1bc565132baf07a269a30422c --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_sd2_unmatched.json @@ -0,0 +1,200 @@ +{ + "ldm": { + "alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 0.00466156005859375, + "max": 0.9990234375 + }, + "alphas_cumprod_prev": { + "shape": [ + 1000 + ], + "min": 0.0047149658203125, + "max": 1.0 + }, + "betas": { + "shape": [ + 1000 + ], + "min": 0.0008502006530761719, + "max": 0.01200103759765625 + }, + "cond_stage_model.model.logit_scale": { + "shape": [], + "min": 4.60546875, + "max": 4.60546875 + }, + "cond_stage_model.model.text_projection": { + "shape": [ + 1024, + 1024 + ], + "min": -0.109130859375, + "max": 0.09271240234375 + }, + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias": { + "shape": [ + 3072 + ], + "min": -2.525390625, + "max": 2.591796875 + }, + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight": { + "shape": [ + 3072, + 1024 + ], + "min": -0.12261962890625, + "max": 0.1258544921875 + }, + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias": { + "shape": [ + 1024 + ], + "min": -0.422607421875, + "max": 1.17578125 + }, + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight": { + "shape": [ + 1024, + 1024 + ], + "min": -0.0738525390625, + "max": 0.08673095703125 + }, + "cond_stage_model.model.transformer.resblocks.23.ln_1.bias": { + "shape": [ + 1024 + ], + 
"min": -3.392578125, + "max": 0.90625 + }, + "cond_stage_model.model.transformer.resblocks.23.ln_1.weight": { + "shape": [ + 1024 + ], + "min": 0.379638671875, + "max": 2.02734375 + }, + "cond_stage_model.model.transformer.resblocks.23.ln_2.bias": { + "shape": [ + 1024 + ], + "min": -0.833984375, + "max": 2.525390625 + }, + "cond_stage_model.model.transformer.resblocks.23.ln_2.weight": { + "shape": [ + 1024 + ], + "min": 1.17578125, + "max": 2.037109375 + }, + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias": { + "shape": [ + 4096 + ], + "min": -1.619140625, + "max": 0.5595703125 + }, + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight": { + "shape": [ + 4096, + 1024 + ], + "min": -0.08953857421875, + "max": 0.13232421875 + }, + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias": { + "shape": [ + 1024 + ], + "min": -1.8662109375, + "max": 0.74658203125 + }, + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight": { + "shape": [ + 1024, + 4096 + ], + "min": -0.12939453125, + "max": 0.1009521484375 + }, + "log_one_minus_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": -7.0703125, + "max": -0.004669189453125 + }, + "model_ema.decay": { + "shape": [], + "min": 1.0, + "max": 1.0 + }, + "model_ema.num_updates": { + "shape": [], + "min": 219996, + "max": 219996 + }, + "posterior_log_variance_clipped": { + "shape": [ + 1000 + ], + "min": -46.0625, + "max": -4.421875 + }, + "posterior_mean_coef1": { + "shape": [ + 1000 + ], + "min": 0.000827789306640625, + "max": 1.0 + }, + "posterior_mean_coef2": { + "shape": [ + 1000 + ], + "min": 0.0, + "max": 0.99560546875 + }, + "posterior_variance": { + "shape": [ + 1000 + ], + "min": 0.0, + "max": 0.01200103759765625 + }, + "sqrt_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 0.0682373046875, + "max": 0.99951171875 + }, + "sqrt_one_minus_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 0.0291595458984375, + "max": 0.99755859375 + }, + "sqrt_recip_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 1.0, + "max": 14.6484375 + }, + "sqrt_recipm1_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 0.0291595458984375, + "max": 14.6171875 + } + }, + "diffusers": {} +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_sdxl.json b/toolkit/keymaps/stable_diffusion_sdxl.json new file mode 100644 index 0000000000000000000000000000000000000000..dd3c24475b9a933839567e20990d0910944ba82e --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_sdxl.json @@ -0,0 +1,4154 @@ +{ + "ldm_diffusers_keymap": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "te0_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te0_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te0_text_model.encoder.layers.0.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te0_text_model.encoder.layers.0.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te0_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te0_text_model.encoder.layers.1.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": 
"te0_text_model.encoder.layers.1.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te0_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te0_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te0_text_model.encoder.layers.11.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te0_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te0_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te0_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te0_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te0_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te0_text_model.encoder.layers.4.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te0_text_model.encoder.layers.4.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te0_text_model.encoder.layers.5.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te0_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te0_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te0_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te0_text_model.encoder.layers.7.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te0_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te0_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te0_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": "te0_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias", + "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": 
"te1_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.13.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": 
"te1_text_model.encoder.layers.19.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.23.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": 
"te1_text_model.encoder.layers.29.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": 
"vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": 
"vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + 
"first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": 
"vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": 
"vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + 
"first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": 
"vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + 
"model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + 
"model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + 
"model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": 
"unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": 
"unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm3.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_v.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_v.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": 
"unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": 
"unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.bias": 
"unet_down_blocks.2.attentions.1.transformer_blocks.5.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.6.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_k.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm3.weight", + "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias", + "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight", + "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias", + "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": 
"unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": 
"unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_k.weight": 
"unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_q.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.weight": 
"unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm1.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm3.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.weight": 
"unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias", + 
"model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": 
"unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.bias": 
"unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_k.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_q.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.1.1.proj_in.bias": "unet_up_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.1.1.proj_in.weight": "unet_up_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.1.1.proj_out.bias": "unet_up_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm2.bias", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm3.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_k.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias", + 
"model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_k.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_q.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_k.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + 
"model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + 
"model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": 
"unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": 
"unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.23.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1280, 
:" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "2560:, :" 
+ ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", 
+ "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": { + 
"slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "2560:, :" + ] + } + } +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..16f4d21046ef187cb3dd34d83b7eaa3aea216394 --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:243672eb340dae3396626886ae9c270ad0d212b9df970d3021037f829d8c70a5 +size 3277308 diff --git a/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json b/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json new file mode 100644 index 0000000000000000000000000000000000000000..d0b2554ae6e6fd8bd12d660cdb64437132c8b52d --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json @@ -0,0 +1,35 @@ +{ + "ldm": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": { + "shape": [ + 1, + 77 + ], + "min": 0.0, + "max": 76.0 + }, + "conditioner.embedders.1.model.logit_scale": { + "shape": [], + "min": 4.60546875, + "max": 4.60546875 + }, + "conditioner.embedders.1.model.text_projection": { + "shape": [ + 1280, + 1280 + ], + "min": -0.15966796875, + "max": 0.230712890625 + } + }, + "diffusers": { + "te1_text_projection.weight": { + "shape": [ + 1280, + 1280 + ], + "min": -0.15966796875, + "max": 0.230712890625 + } + } +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_ssd.json b/toolkit/keymaps/stable_diffusion_ssd.json new file mode 100644 index 0000000000000000000000000000000000000000..9ad06407be7c6eedb4fcfa06805827bf1d2f6924 --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_ssd.json @@ -0,0 +1,3419 @@ +{ + "ldm_diffusers_keymap": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "te0_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": 
"te0_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te0_text_model.encoder.layers.0.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te0_text_model.encoder.layers.0.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te0_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te0_text_model.encoder.layers.1.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te0_text_model.encoder.layers.1.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight", 
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te0_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te0_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": 
"te0_text_model.encoder.layers.11.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te0_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": 
"te0_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te0_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te0_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te0_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": 
"te0_text_model.encoder.layers.4.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te0_text_model.encoder.layers.4.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te0_text_model.encoder.layers.5.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te0_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te0_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te0_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te0_text_model.encoder.layers.7.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias", 
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te0_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te0_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te0_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": "te0_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias", + "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.1.model.text_projection.weight": "te1_text_projection.weight", + "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": 
"te1_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.16.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": 
"te1_text_model.encoder.layers.21.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.26.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": 
"te1_text_model.encoder.layers.31.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + 
"first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + 
"first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": 
"vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + 
"first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + 
"first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": 
"vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + 
"first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": 
"unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + 
"model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": 
"unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias", + "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight", + "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias", + "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias", + 
"model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.1.1.proj_in.bias": 
"unet_up_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.1.1.proj_in.weight": "unet_up_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.1.1.proj_out.bias": "unet_up_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_q.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.bias": 
"unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight", + 
"model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + 
"model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + 
"model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": 
"unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + 
"model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.28.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1024, :" + ] + }, + 
"te1_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "2048:, :" + ] + 
}, + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "2048:, :" + ] + }, + 
"te1_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "1024:2048, :" + ] + 
}, + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + 
"0:1024, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": { + "slice": [ 
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "2048:, :" + ] + } + } +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..3936b58486f98f1199a5cd7358fd3846ae61199d --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0df922c1d1dd2ff13d557ac95a8c867ba3319a87f1a3d74aeb2022a64361f914 +size 572 diff --git a/toolkit/keymaps/stable_diffusion_ssd_unmatched.json b/toolkit/keymaps/stable_diffusion_ssd_unmatched.json new file mode 100644 index 0000000000000000000000000000000000000000..6871c9eb6af0c9a4018e1b0ead6c9fd7c7ee387b --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_ssd_unmatched.json @@ -0,0 +1,21 @@ +{ + "ldm": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": { + "shape": [ + 1, + 77 + ], + "min": 0.0, + "max": 76.0 + }, + "conditioner.embedders.1.model.text_model.embeddings.position_ids": { + "shape": [ + 1, + 77 + ], + "min": 0.0, + "max": 76.0 + } + }, + "diffusers": {} +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_vega.json b/toolkit/keymaps/stable_diffusion_vega.json new file mode 100644 index 0000000000000000000000000000000000000000..4117c201963bea780c16acd720055699b92acf43 --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_vega.json @@ -0,0 +1,3039 @@ +{ + "ldm_diffusers_keymap": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "te0_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te0_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te0_text_model.encoder.layers.0.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te0_text_model.encoder.layers.0.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te0_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": 
"te0_text_model.encoder.layers.1.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te0_text_model.encoder.layers.1.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te0_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te0_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te0_text_model.encoder.layers.11.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te0_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te0_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te0_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te0_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": 
"te0_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te0_text_model.encoder.layers.4.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te0_text_model.encoder.layers.4.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te0_text_model.encoder.layers.5.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te0_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te0_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te0_text_model.encoder.layers.7.mlp.fc2.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te0_text_model.encoder.layers.7.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te0_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te0_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te0_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": "te0_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias", + "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.1.model.text_projection.weight": "te1_text_projection.weight", + "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": 
"te1_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": 
"te1_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": 
"te1_text_model.encoder.layers.16.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": 
"te1_text_model.encoder.layers.21.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": 
"te1_text_model.encoder.layers.26.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": 
"te1_text_model.encoder.layers.31.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": 
"vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": 
"vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + 
"first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": 
"vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + 
"first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": 
"vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + 
"first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": 
"unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": 
"unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": 
"unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": 
"unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + 
"model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": 
"unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias", + "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight", + "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias", + 
"model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": 
"unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + 
"model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.1.1.proj_in.bias": "unet_up_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.1.1.proj_in.weight": "unet_up_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.1.1.proj_out.bias": "unet_up_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": 
"unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + 
"model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": 
"unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 
512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.25.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1280, :" + ] + }, 
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": { 
+ "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + 
"1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": 
{ + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.31.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "2560:, :" + ] + } + } +} \ No newline at end of file diff --git a/toolkit/keymaps/stable_diffusion_vega_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_vega_ldm_base.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..8e2c4cb90b8d10d6c9a844a3b73ef3e07541f130 --- /dev/null +++ b/toolkit/keymaps/stable_diffusion_vega_ldm_base.safetensors @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576 +size 16 diff --git a/toolkit/kohya_model_util.py b/toolkit/kohya_model_util.py new file mode 100644 index 0000000000000000000000000000000000000000..798fc2dccb5787973cdf2bbab769459f22b3a805 --- /dev/null +++ b/toolkit/kohya_model_util.py @@ -0,0 +1,1533 @@ +# mostly from https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py +# I am infinitely grateful to @kohya-ss for their amazing work in this field. +# This version is updated to handle the latest version of the diffusers library. +import json +# v1: split from train_db_fixed.py. +# v2: support safetensors + +import math +import os +import re + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +# model parameters for the Diffusers version of Stable Diffusion +NUM_TRAIN_TIMESTEPS = 1000 +BETA_START = 0.00085 +BETA_END = 0.0120 + +UNET_PARAMS_MODEL_CHANNELS = 320 +UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] +UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] +UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32` +UNET_PARAMS_IN_CHANNELS = 4 +UNET_PARAMS_OUT_CHANNELS = 4 +UNET_PARAMS_NUM_RES_BLOCKS = 2 +UNET_PARAMS_CONTEXT_DIM = 768 +UNET_PARAMS_NUM_HEADS = 8 +# UNET_PARAMS_USE_LINEAR_PROJECTION = False + +VAE_PARAMS_Z_CHANNELS = 4 +VAE_PARAMS_RESOLUTION = 256 +VAE_PARAMS_IN_CHANNELS = 3 +VAE_PARAMS_OUT_CH = 3 +VAE_PARAMS_CH = 128 +VAE_PARAMS_CH_MULT = [1, 2, 4, 4] +VAE_PARAMS_NUM_RES_BLOCKS = 2 + +# V2 +V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] +V2_UNET_PARAMS_CONTEXT_DIM = 1024 +# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True + +# reference models used to load the Diffusers configuration +DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" +DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" + + +# region StableDiffusion->Diffusers conversion code +# copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0) + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # updated for latest diffusers + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 
+ + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + + +def convert_ldm_unet_checkpoint(v2, checkpoint, config): + mapping = {} + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." 
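+    # pop every key that starts with "model.diffusion_model." out of the checkpoint and store it in unet_state_dict with the prefix stripped, so only UNet tensors are remapped below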
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in + range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in + range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." 
in key] for layer_id in + range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [key for key in input_blocks[i] if + f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + mapping[f'input_blocks.{i}.0.op.weight'] = f"down_blocks.{block_id}.downsamplers.0.conv.weight" + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias") + mapping[f'input_blocks.{i}.0.op.bias'] = f"down_blocks.{block_id}.downsamplers.0.conv.bias" + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) + + # original: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + #     index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # make this independent of the order of bias and weight: there is probably a better way + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) +
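# the entry that sorts to ["conv.bias", "conv.weight"] is this output block's upsampler conv; its weight and bias are copied over directly +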
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + # SDのv2では1*1のconv2dがlinearに変わっている + # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 + if v2 and not config.get('use_linear_projection', False): + linear_transformer_to_conv(new_checkpoint) + + # print("mapping: ", json.dumps(mapping, indent=4)) + return new_checkpoint + + +# ldm key: diffusers key +vae_ldm_to_diffusers_dict = { + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.mid.attn_1.k.bias": "decoder.mid_block.attentions.0.to_k.bias", + "decoder.mid.attn_1.k.weight": "decoder.mid_block.attentions.0.to_k.weight", + "decoder.mid.attn_1.norm.bias": "decoder.mid_block.attentions.0.group_norm.bias", + "decoder.mid.attn_1.norm.weight": "decoder.mid_block.attentions.0.group_norm.weight", + "decoder.mid.attn_1.proj_out.bias": "decoder.mid_block.attentions.0.to_out.0.bias", + "decoder.mid.attn_1.proj_out.weight": "decoder.mid_block.attentions.0.to_out.0.weight", + "decoder.mid.attn_1.q.bias": "decoder.mid_block.attentions.0.to_q.bias", + "decoder.mid.attn_1.q.weight": "decoder.mid_block.attentions.0.to_q.weight", + "decoder.mid.attn_1.v.bias": "decoder.mid_block.attentions.0.to_v.bias", + "decoder.mid.attn_1.v.weight": "decoder.mid_block.attentions.0.to_v.weight", + "decoder.mid.block_1.conv1.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.mid.block_1.conv1.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.mid.block_1.conv2.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.mid.block_1.conv2.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.mid.block_1.norm1.bias": "decoder.mid_block.resnets.0.norm1.bias", + "decoder.mid.block_1.norm1.weight": "decoder.mid_block.resnets.0.norm1.weight", + "decoder.mid.block_1.norm2.bias": "decoder.mid_block.resnets.0.norm2.bias", + "decoder.mid.block_1.norm2.weight": "decoder.mid_block.resnets.0.norm2.weight", + "decoder.mid.block_2.conv1.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.mid.block_2.conv1.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.mid.block_2.conv2.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.mid.block_2.conv2.weight": "decoder.mid_block.resnets.1.conv2.weight", + "decoder.mid.block_2.norm1.bias": "decoder.mid_block.resnets.1.norm1.bias", + "decoder.mid.block_2.norm1.weight": "decoder.mid_block.resnets.1.norm1.weight", + "decoder.mid.block_2.norm2.bias": 
"decoder.mid_block.resnets.1.norm2.bias", + "decoder.mid.block_2.norm2.weight": "decoder.mid_block.resnets.1.norm2.weight", + "decoder.norm_out.bias": "decoder.conv_norm_out.bias", + "decoder.norm_out.weight": "decoder.conv_norm_out.weight", + "decoder.up.0.block.0.conv1.bias": "decoder.up_blocks.3.resnets.0.conv1.bias", + "decoder.up.0.block.0.conv1.weight": "decoder.up_blocks.3.resnets.0.conv1.weight", + "decoder.up.0.block.0.conv2.bias": "decoder.up_blocks.3.resnets.0.conv2.bias", + "decoder.up.0.block.0.conv2.weight": "decoder.up_blocks.3.resnets.0.conv2.weight", + "decoder.up.0.block.0.nin_shortcut.bias": "decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "decoder.up.0.block.0.nin_shortcut.weight": "decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "decoder.up.0.block.0.norm1.bias": "decoder.up_blocks.3.resnets.0.norm1.bias", + "decoder.up.0.block.0.norm1.weight": "decoder.up_blocks.3.resnets.0.norm1.weight", + "decoder.up.0.block.0.norm2.bias": "decoder.up_blocks.3.resnets.0.norm2.bias", + "decoder.up.0.block.0.norm2.weight": "decoder.up_blocks.3.resnets.0.norm2.weight", + "decoder.up.0.block.1.conv1.bias": "decoder.up_blocks.3.resnets.1.conv1.bias", + "decoder.up.0.block.1.conv1.weight": "decoder.up_blocks.3.resnets.1.conv1.weight", + "decoder.up.0.block.1.conv2.bias": "decoder.up_blocks.3.resnets.1.conv2.bias", + "decoder.up.0.block.1.conv2.weight": "decoder.up_blocks.3.resnets.1.conv2.weight", + "decoder.up.0.block.1.norm1.bias": "decoder.up_blocks.3.resnets.1.norm1.bias", + "decoder.up.0.block.1.norm1.weight": "decoder.up_blocks.3.resnets.1.norm1.weight", + "decoder.up.0.block.1.norm2.bias": "decoder.up_blocks.3.resnets.1.norm2.bias", + "decoder.up.0.block.1.norm2.weight": "decoder.up_blocks.3.resnets.1.norm2.weight", + "decoder.up.0.block.2.conv1.bias": "decoder.up_blocks.3.resnets.2.conv1.bias", + "decoder.up.0.block.2.conv1.weight": "decoder.up_blocks.3.resnets.2.conv1.weight", + "decoder.up.0.block.2.conv2.bias": "decoder.up_blocks.3.resnets.2.conv2.bias", + "decoder.up.0.block.2.conv2.weight": "decoder.up_blocks.3.resnets.2.conv2.weight", + "decoder.up.0.block.2.norm1.bias": "decoder.up_blocks.3.resnets.2.norm1.bias", + "decoder.up.0.block.2.norm1.weight": "decoder.up_blocks.3.resnets.2.norm1.weight", + "decoder.up.0.block.2.norm2.bias": "decoder.up_blocks.3.resnets.2.norm2.bias", + "decoder.up.0.block.2.norm2.weight": "decoder.up_blocks.3.resnets.2.norm2.weight", + "decoder.up.1.block.0.conv1.bias": "decoder.up_blocks.2.resnets.0.conv1.bias", + "decoder.up.1.block.0.conv1.weight": "decoder.up_blocks.2.resnets.0.conv1.weight", + "decoder.up.1.block.0.conv2.bias": "decoder.up_blocks.2.resnets.0.conv2.bias", + "decoder.up.1.block.0.conv2.weight": "decoder.up_blocks.2.resnets.0.conv2.weight", + "decoder.up.1.block.0.nin_shortcut.bias": "decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "decoder.up.1.block.0.nin_shortcut.weight": "decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "decoder.up.1.block.0.norm1.bias": "decoder.up_blocks.2.resnets.0.norm1.bias", + "decoder.up.1.block.0.norm1.weight": "decoder.up_blocks.2.resnets.0.norm1.weight", + "decoder.up.1.block.0.norm2.bias": "decoder.up_blocks.2.resnets.0.norm2.bias", + "decoder.up.1.block.0.norm2.weight": "decoder.up_blocks.2.resnets.0.norm2.weight", + "decoder.up.1.block.1.conv1.bias": "decoder.up_blocks.2.resnets.1.conv1.bias", + "decoder.up.1.block.1.conv1.weight": "decoder.up_blocks.2.resnets.1.conv1.weight", + "decoder.up.1.block.1.conv2.bias": "decoder.up_blocks.2.resnets.1.conv2.bias", + 
"decoder.up.1.block.1.conv2.weight": "decoder.up_blocks.2.resnets.1.conv2.weight", + "decoder.up.1.block.1.norm1.bias": "decoder.up_blocks.2.resnets.1.norm1.bias", + "decoder.up.1.block.1.norm1.weight": "decoder.up_blocks.2.resnets.1.norm1.weight", + "decoder.up.1.block.1.norm2.bias": "decoder.up_blocks.2.resnets.1.norm2.bias", + "decoder.up.1.block.1.norm2.weight": "decoder.up_blocks.2.resnets.1.norm2.weight", + "decoder.up.1.block.2.conv1.bias": "decoder.up_blocks.2.resnets.2.conv1.bias", + "decoder.up.1.block.2.conv1.weight": "decoder.up_blocks.2.resnets.2.conv1.weight", + "decoder.up.1.block.2.conv2.bias": "decoder.up_blocks.2.resnets.2.conv2.bias", + "decoder.up.1.block.2.conv2.weight": "decoder.up_blocks.2.resnets.2.conv2.weight", + "decoder.up.1.block.2.norm1.bias": "decoder.up_blocks.2.resnets.2.norm1.bias", + "decoder.up.1.block.2.norm1.weight": "decoder.up_blocks.2.resnets.2.norm1.weight", + "decoder.up.1.block.2.norm2.bias": "decoder.up_blocks.2.resnets.2.norm2.bias", + "decoder.up.1.block.2.norm2.weight": "decoder.up_blocks.2.resnets.2.norm2.weight", + "decoder.up.1.upsample.conv.bias": "decoder.up_blocks.2.upsamplers.0.conv.bias", + "decoder.up.1.upsample.conv.weight": "decoder.up_blocks.2.upsamplers.0.conv.weight", + "decoder.up.2.block.0.conv1.bias": "decoder.up_blocks.1.resnets.0.conv1.bias", + "decoder.up.2.block.0.conv1.weight": "decoder.up_blocks.1.resnets.0.conv1.weight", + "decoder.up.2.block.0.conv2.bias": "decoder.up_blocks.1.resnets.0.conv2.bias", + "decoder.up.2.block.0.conv2.weight": "decoder.up_blocks.1.resnets.0.conv2.weight", + "decoder.up.2.block.0.norm1.bias": "decoder.up_blocks.1.resnets.0.norm1.bias", + "decoder.up.2.block.0.norm1.weight": "decoder.up_blocks.1.resnets.0.norm1.weight", + "decoder.up.2.block.0.norm2.bias": "decoder.up_blocks.1.resnets.0.norm2.bias", + "decoder.up.2.block.0.norm2.weight": "decoder.up_blocks.1.resnets.0.norm2.weight", + "decoder.up.2.block.1.conv1.bias": "decoder.up_blocks.1.resnets.1.conv1.bias", + "decoder.up.2.block.1.conv1.weight": "decoder.up_blocks.1.resnets.1.conv1.weight", + "decoder.up.2.block.1.conv2.bias": "decoder.up_blocks.1.resnets.1.conv2.bias", + "decoder.up.2.block.1.conv2.weight": "decoder.up_blocks.1.resnets.1.conv2.weight", + "decoder.up.2.block.1.norm1.bias": "decoder.up_blocks.1.resnets.1.norm1.bias", + "decoder.up.2.block.1.norm1.weight": "decoder.up_blocks.1.resnets.1.norm1.weight", + "decoder.up.2.block.1.norm2.bias": "decoder.up_blocks.1.resnets.1.norm2.bias", + "decoder.up.2.block.1.norm2.weight": "decoder.up_blocks.1.resnets.1.norm2.weight", + "decoder.up.2.block.2.conv1.bias": "decoder.up_blocks.1.resnets.2.conv1.bias", + "decoder.up.2.block.2.conv1.weight": "decoder.up_blocks.1.resnets.2.conv1.weight", + "decoder.up.2.block.2.conv2.bias": "decoder.up_blocks.1.resnets.2.conv2.bias", + "decoder.up.2.block.2.conv2.weight": "decoder.up_blocks.1.resnets.2.conv2.weight", + "decoder.up.2.block.2.norm1.bias": "decoder.up_blocks.1.resnets.2.norm1.bias", + "decoder.up.2.block.2.norm1.weight": "decoder.up_blocks.1.resnets.2.norm1.weight", + "decoder.up.2.block.2.norm2.bias": "decoder.up_blocks.1.resnets.2.norm2.bias", + "decoder.up.2.block.2.norm2.weight": "decoder.up_blocks.1.resnets.2.norm2.weight", + "decoder.up.2.upsample.conv.bias": "decoder.up_blocks.1.upsamplers.0.conv.bias", + "decoder.up.2.upsample.conv.weight": "decoder.up_blocks.1.upsamplers.0.conv.weight", + "decoder.up.3.block.0.conv1.bias": "decoder.up_blocks.0.resnets.0.conv1.bias", + "decoder.up.3.block.0.conv1.weight": 
"decoder.up_blocks.0.resnets.0.conv1.weight", + "decoder.up.3.block.0.conv2.bias": "decoder.up_blocks.0.resnets.0.conv2.bias", + "decoder.up.3.block.0.conv2.weight": "decoder.up_blocks.0.resnets.0.conv2.weight", + "decoder.up.3.block.0.norm1.bias": "decoder.up_blocks.0.resnets.0.norm1.bias", + "decoder.up.3.block.0.norm1.weight": "decoder.up_blocks.0.resnets.0.norm1.weight", + "decoder.up.3.block.0.norm2.bias": "decoder.up_blocks.0.resnets.0.norm2.bias", + "decoder.up.3.block.0.norm2.weight": "decoder.up_blocks.0.resnets.0.norm2.weight", + "decoder.up.3.block.1.conv1.bias": "decoder.up_blocks.0.resnets.1.conv1.bias", + "decoder.up.3.block.1.conv1.weight": "decoder.up_blocks.0.resnets.1.conv1.weight", + "decoder.up.3.block.1.conv2.bias": "decoder.up_blocks.0.resnets.1.conv2.bias", + "decoder.up.3.block.1.conv2.weight": "decoder.up_blocks.0.resnets.1.conv2.weight", + "decoder.up.3.block.1.norm1.bias": "decoder.up_blocks.0.resnets.1.norm1.bias", + "decoder.up.3.block.1.norm1.weight": "decoder.up_blocks.0.resnets.1.norm1.weight", + "decoder.up.3.block.1.norm2.bias": "decoder.up_blocks.0.resnets.1.norm2.bias", + "decoder.up.3.block.1.norm2.weight": "decoder.up_blocks.0.resnets.1.norm2.weight", + "decoder.up.3.block.2.conv1.bias": "decoder.up_blocks.0.resnets.2.conv1.bias", + "decoder.up.3.block.2.conv1.weight": "decoder.up_blocks.0.resnets.2.conv1.weight", + "decoder.up.3.block.2.conv2.bias": "decoder.up_blocks.0.resnets.2.conv2.bias", + "decoder.up.3.block.2.conv2.weight": "decoder.up_blocks.0.resnets.2.conv2.weight", + "decoder.up.3.block.2.norm1.bias": "decoder.up_blocks.0.resnets.2.norm1.bias", + "decoder.up.3.block.2.norm1.weight": "decoder.up_blocks.0.resnets.2.norm1.weight", + "decoder.up.3.block.2.norm2.bias": "decoder.up_blocks.0.resnets.2.norm2.bias", + "decoder.up.3.block.2.norm2.weight": "decoder.up_blocks.0.resnets.2.norm2.weight", + "decoder.up.3.upsample.conv.bias": "decoder.up_blocks.0.upsamplers.0.conv.bias", + "decoder.up.3.upsample.conv.weight": "decoder.up_blocks.0.upsamplers.0.conv.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.down.0.block.0.conv1.bias": "encoder.down_blocks.0.resnets.0.conv1.bias", + "encoder.down.0.block.0.conv1.weight": "encoder.down_blocks.0.resnets.0.conv1.weight", + "encoder.down.0.block.0.conv2.bias": "encoder.down_blocks.0.resnets.0.conv2.bias", + "encoder.down.0.block.0.conv2.weight": "encoder.down_blocks.0.resnets.0.conv2.weight", + "encoder.down.0.block.0.norm1.bias": "encoder.down_blocks.0.resnets.0.norm1.bias", + "encoder.down.0.block.0.norm1.weight": "encoder.down_blocks.0.resnets.0.norm1.weight", + "encoder.down.0.block.0.norm2.bias": "encoder.down_blocks.0.resnets.0.norm2.bias", + "encoder.down.0.block.0.norm2.weight": "encoder.down_blocks.0.resnets.0.norm2.weight", + "encoder.down.0.block.1.conv1.bias": "encoder.down_blocks.0.resnets.1.conv1.bias", + "encoder.down.0.block.1.conv1.weight": "encoder.down_blocks.0.resnets.1.conv1.weight", + "encoder.down.0.block.1.conv2.bias": "encoder.down_blocks.0.resnets.1.conv2.bias", + "encoder.down.0.block.1.conv2.weight": "encoder.down_blocks.0.resnets.1.conv2.weight", + "encoder.down.0.block.1.norm1.bias": "encoder.down_blocks.0.resnets.1.norm1.bias", + "encoder.down.0.block.1.norm1.weight": "encoder.down_blocks.0.resnets.1.norm1.weight", + "encoder.down.0.block.1.norm2.bias": 
"encoder.down_blocks.0.resnets.1.norm2.bias", + "encoder.down.0.block.1.norm2.weight": "encoder.down_blocks.0.resnets.1.norm2.weight", + "encoder.down.0.downsample.conv.bias": "encoder.down_blocks.0.downsamplers.0.conv.bias", + "encoder.down.0.downsample.conv.weight": "encoder.down_blocks.0.downsamplers.0.conv.weight", + "encoder.down.1.block.0.conv1.bias": "encoder.down_blocks.1.resnets.0.conv1.bias", + "encoder.down.1.block.0.conv1.weight": "encoder.down_blocks.1.resnets.0.conv1.weight", + "encoder.down.1.block.0.conv2.bias": "encoder.down_blocks.1.resnets.0.conv2.bias", + "encoder.down.1.block.0.conv2.weight": "encoder.down_blocks.1.resnets.0.conv2.weight", + "encoder.down.1.block.0.nin_shortcut.bias": "encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "encoder.down.1.block.0.nin_shortcut.weight": "encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "encoder.down.1.block.0.norm1.bias": "encoder.down_blocks.1.resnets.0.norm1.bias", + "encoder.down.1.block.0.norm1.weight": "encoder.down_blocks.1.resnets.0.norm1.weight", + "encoder.down.1.block.0.norm2.bias": "encoder.down_blocks.1.resnets.0.norm2.bias", + "encoder.down.1.block.0.norm2.weight": "encoder.down_blocks.1.resnets.0.norm2.weight", + "encoder.down.1.block.1.conv1.bias": "encoder.down_blocks.1.resnets.1.conv1.bias", + "encoder.down.1.block.1.conv1.weight": "encoder.down_blocks.1.resnets.1.conv1.weight", + "encoder.down.1.block.1.conv2.bias": "encoder.down_blocks.1.resnets.1.conv2.bias", + "encoder.down.1.block.1.conv2.weight": "encoder.down_blocks.1.resnets.1.conv2.weight", + "encoder.down.1.block.1.norm1.bias": "encoder.down_blocks.1.resnets.1.norm1.bias", + "encoder.down.1.block.1.norm1.weight": "encoder.down_blocks.1.resnets.1.norm1.weight", + "encoder.down.1.block.1.norm2.bias": "encoder.down_blocks.1.resnets.1.norm2.bias", + "encoder.down.1.block.1.norm2.weight": "encoder.down_blocks.1.resnets.1.norm2.weight", + "encoder.down.1.downsample.conv.bias": "encoder.down_blocks.1.downsamplers.0.conv.bias", + "encoder.down.1.downsample.conv.weight": "encoder.down_blocks.1.downsamplers.0.conv.weight", + "encoder.down.2.block.0.conv1.bias": "encoder.down_blocks.2.resnets.0.conv1.bias", + "encoder.down.2.block.0.conv1.weight": "encoder.down_blocks.2.resnets.0.conv1.weight", + "encoder.down.2.block.0.conv2.bias": "encoder.down_blocks.2.resnets.0.conv2.bias", + "encoder.down.2.block.0.conv2.weight": "encoder.down_blocks.2.resnets.0.conv2.weight", + "encoder.down.2.block.0.nin_shortcut.bias": "encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "encoder.down.2.block.0.nin_shortcut.weight": "encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "encoder.down.2.block.0.norm1.bias": "encoder.down_blocks.2.resnets.0.norm1.bias", + "encoder.down.2.block.0.norm1.weight": "encoder.down_blocks.2.resnets.0.norm1.weight", + "encoder.down.2.block.0.norm2.bias": "encoder.down_blocks.2.resnets.0.norm2.bias", + "encoder.down.2.block.0.norm2.weight": "encoder.down_blocks.2.resnets.0.norm2.weight", + "encoder.down.2.block.1.conv1.bias": "encoder.down_blocks.2.resnets.1.conv1.bias", + "encoder.down.2.block.1.conv1.weight": "encoder.down_blocks.2.resnets.1.conv1.weight", + "encoder.down.2.block.1.conv2.bias": "encoder.down_blocks.2.resnets.1.conv2.bias", + "encoder.down.2.block.1.conv2.weight": "encoder.down_blocks.2.resnets.1.conv2.weight", + "encoder.down.2.block.1.norm1.bias": "encoder.down_blocks.2.resnets.1.norm1.bias", + "encoder.down.2.block.1.norm1.weight": "encoder.down_blocks.2.resnets.1.norm1.weight", + 
"encoder.down.2.block.1.norm2.bias": "encoder.down_blocks.2.resnets.1.norm2.bias", + "encoder.down.2.block.1.norm2.weight": "encoder.down_blocks.2.resnets.1.norm2.weight", + "encoder.down.2.downsample.conv.bias": "encoder.down_blocks.2.downsamplers.0.conv.bias", + "encoder.down.2.downsample.conv.weight": "encoder.down_blocks.2.downsamplers.0.conv.weight", + "encoder.down.3.block.0.conv1.bias": "encoder.down_blocks.3.resnets.0.conv1.bias", + "encoder.down.3.block.0.conv1.weight": "encoder.down_blocks.3.resnets.0.conv1.weight", + "encoder.down.3.block.0.conv2.bias": "encoder.down_blocks.3.resnets.0.conv2.bias", + "encoder.down.3.block.0.conv2.weight": "encoder.down_blocks.3.resnets.0.conv2.weight", + "encoder.down.3.block.0.norm1.bias": "encoder.down_blocks.3.resnets.0.norm1.bias", + "encoder.down.3.block.0.norm1.weight": "encoder.down_blocks.3.resnets.0.norm1.weight", + "encoder.down.3.block.0.norm2.bias": "encoder.down_blocks.3.resnets.0.norm2.bias", + "encoder.down.3.block.0.norm2.weight": "encoder.down_blocks.3.resnets.0.norm2.weight", + "encoder.down.3.block.1.conv1.bias": "encoder.down_blocks.3.resnets.1.conv1.bias", + "encoder.down.3.block.1.conv1.weight": "encoder.down_blocks.3.resnets.1.conv1.weight", + "encoder.down.3.block.1.conv2.bias": "encoder.down_blocks.3.resnets.1.conv2.bias", + "encoder.down.3.block.1.conv2.weight": "encoder.down_blocks.3.resnets.1.conv2.weight", + "encoder.down.3.block.1.norm1.bias": "encoder.down_blocks.3.resnets.1.norm1.bias", + "encoder.down.3.block.1.norm1.weight": "encoder.down_blocks.3.resnets.1.norm1.weight", + "encoder.down.3.block.1.norm2.bias": "encoder.down_blocks.3.resnets.1.norm2.bias", + "encoder.down.3.block.1.norm2.weight": "encoder.down_blocks.3.resnets.1.norm2.weight", + "encoder.mid.attn_1.k.bias": "encoder.mid_block.attentions.0.to_k.bias", + "encoder.mid.attn_1.k.weight": "encoder.mid_block.attentions.0.to_k.weight", + "encoder.mid.attn_1.norm.bias": "encoder.mid_block.attentions.0.group_norm.bias", + "encoder.mid.attn_1.norm.weight": "encoder.mid_block.attentions.0.group_norm.weight", + "encoder.mid.attn_1.proj_out.bias": "encoder.mid_block.attentions.0.to_out.0.bias", + "encoder.mid.attn_1.proj_out.weight": "encoder.mid_block.attentions.0.to_out.0.weight", + "encoder.mid.attn_1.q.bias": "encoder.mid_block.attentions.0.to_q.bias", + "encoder.mid.attn_1.q.weight": "encoder.mid_block.attentions.0.to_q.weight", + "encoder.mid.attn_1.v.bias": "encoder.mid_block.attentions.0.to_v.bias", + "encoder.mid.attn_1.v.weight": "encoder.mid_block.attentions.0.to_v.weight", + "encoder.mid.block_1.conv1.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.mid.block_1.conv1.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.mid.block_1.conv2.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.mid.block_1.conv2.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.mid.block_1.norm1.bias": "encoder.mid_block.resnets.0.norm1.bias", + "encoder.mid.block_1.norm1.weight": "encoder.mid_block.resnets.0.norm1.weight", + "encoder.mid.block_1.norm2.bias": "encoder.mid_block.resnets.0.norm2.bias", + "encoder.mid.block_1.norm2.weight": "encoder.mid_block.resnets.0.norm2.weight", + "encoder.mid.block_2.conv1.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.mid.block_2.conv1.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.mid.block_2.conv2.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.mid.block_2.conv2.weight": "encoder.mid_block.resnets.1.conv2.weight", + 
"encoder.mid.block_2.norm1.bias": "encoder.mid_block.resnets.1.norm1.bias", + "encoder.mid.block_2.norm1.weight": "encoder.mid_block.resnets.1.norm1.weight", + "encoder.mid.block_2.norm2.bias": "encoder.mid_block.resnets.1.norm2.bias", + "encoder.mid.block_2.norm2.weight": "encoder.mid_block.resnets.1.norm2.weight", + "encoder.norm_out.bias": "encoder.conv_norm_out.bias", + "encoder.norm_out.weight": "encoder.conv_norm_out.weight", + "post_quant_conv.bias": "post_quant_conv.bias", + "post_quant_conv.weight": "post_quant_conv.weight", + "quant_conv.bias": "quant_conv.bias", + "quant_conv.weight": "quant_conv.weight" +} + + +def get_diffusers_vae_key_from_ldm_key(target_ldm_key, i=None): + for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): + if i is not None: + ldm_key = ldm_key.replace("{i}", str(i)) + diffusers_key = diffusers_key.replace("{i}", str(i)) + if ldm_key == target_ldm_key: + return diffusers_key + + if ldm_key in vae_ldm_to_diffusers_dict: + return vae_ldm_to_diffusers_dict[ldm_key] + else: + return None + +# def get_ldm_vae_key_from_diffusers_key(target_diffusers_key): +# for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): +# if diffusers_key == target_diffusers_key: +# return ldm_key +# return None + +def get_ldm_vae_key_from_diffusers_key(target_diffusers_key): + for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): + if "{" in diffusers_key: # if we have a placeholder + # escape special characters in the key, and replace the placeholder with a regex group + pattern = re.escape(diffusers_key).replace("\\{i\\}", "(\\d+)") + match = re.match(pattern, target_diffusers_key) + if match: # if we found a match + return ldm_key.format(i=match.group(1)) + elif diffusers_key == target_diffusers_key: + return ldm_key + return None + + +vae_keys_squished_on_diffusers = [ + "decoder.mid_block.attentions.0.to_k.weight", + "decoder.mid_block.attentions.0.to_out.0.weight", + "decoder.mid_block.attentions.0.to_q.weight", + "decoder.mid_block.attentions.0.to_v.weight", + "encoder.mid_block.attentions.0.to_k.weight", + "encoder.mid_block.attentions.0.to_out.0.weight", + "encoder.mid_block.attentions.0.to_q.weight", + "encoder.mid_block.attentions.0.to_v.weight" +] + +def convert_diffusers_back_to_ldm(diffusers_vae): + new_state_dict = OrderedDict() + diffusers_state_dict = diffusers_vae.state_dict() + for key, value in diffusers_state_dict.items(): + val_to_save = value + if key in vae_keys_squished_on_diffusers: + val_to_save = value.clone() + # (512, 512) diffusers and (512, 512, 1, 1) ldm + val_to_save = val_to_save.unsqueeze(-1).unsqueeze(-1) + ldm_key = get_ldm_vae_key_from_diffusers_key(key) + if ldm_key is not None: + new_state_dict[ldm_key] = val_to_save + else: + # for now add current key + new_state_dict[key] = val_to_save + return new_state_dict + + +def convert_ldm_vae_checkpoint(checkpoint, config): + mapping = {} + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." 
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + # for key in list(vae_state_dict.keys()): + # diffusers_key = get_diffusers_vae_key_from_ldm_key(key) + # if diffusers_key is not None: + # new_checkpoint[diffusers_key] = vae_state_dict[key] + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in + range(num_down_blocks)} + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in + range(num_up_blocks)} + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + mapping[f"encoder.down.{i}.downsample.conv.weight"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + mapping[f"encoder.down.{i}.downsample.conv.bias"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = 
[key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if + f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + mapping[f"decoder.up.{block_id}.upsample.conv.weight"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + mapping[f"decoder.up.{block_id}.upsample.conv.bias"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, + ) + if v2 and use_linear_projection_in_v2: + config["use_linear_projection"] = True + + return config + + +def create_vae_diffusers_config(): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config + + +def convert_ldm_clip_checkpoint_v1(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] + # support checkpoint without position_ids (invalid checkpoint) + if "text_model.embeddings.position_ids" not in text_model_dict: + text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text + return text_model_dict + + +def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None + + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." 
in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = None # 使われない??? + elif ".logit_scale" in key: + key = None # 使われない??? + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if ".resblocks.23." in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if ".resblocks.23." in key: + continue + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd + + +# endregion + + +# region Diffusers->StableDiffusion の変換コード +# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) + + +def conv_transformer_to_linear(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + + +def convert_unet_state_dict_to_sd(v2, unet_state_dict): + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = 
f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3 * i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2 * j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] + + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3 - i}.upsample." 
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3 - i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i + 1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + + return new_state_dict + + +# endregion + +# region 自作のモデル読み書きなど + + +def is_safetensors(path): + return os.path.splitext(path)[1].lower() == ".safetensors" + + +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), + ] + + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path) # , device) # may causes error + else: + checkpoint = torch.load(ckpt_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None + + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from):] + key_reps.append((key, new_key)) + + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + return checkpoint, state_dict + + +# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, + unet_use_linear_projection_in_v2=False): + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) + + unet = UNet2DConditionModel(**unet_config).to(device) + info = unet.load_state_dict(converted_unet_checkpoint) + print("loading u-net:", info) + + # Convert the VAE model. 
+ vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + + vae = AutoencoderKL(**vae_config).to(device) + info = vae.load_state_dict(converted_vae_checkpoint) + print("loading vae:", info) + + # convert text_model + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=512, + torch_dtype="float32", + transformers_version="4.25.0.dev0", + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) + + logging.set_verbosity_error() # don't show annoying warning + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + logging.set_verbosity_warning() + + # latest transformers doesnt have position ids. Do we remove it? + if "text_model.embeddings.position_ids" not in text_model.state_dict(): + del converted_text_encoder_checkpoint["text_model.embeddings.position_ids"] + + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print("loading text encoder:", info) + + return text_model, vae, unet + + +def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." 
in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) + for key in keys: + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + + # Diffusersに含まれない重みを作っておく + new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd["logit_scale"] = torch.tensor(1) + + return new_sd + + +def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, + vae=None): + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) + + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) + + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + # epoch and global_step are sometimes not int + try: + if "epoch" in checkpoint: + epochs += checkpoint["epoch"] + if "global_step" in checkpoint: + steps += checkpoint["global_step"] + except: + pass + + new_ckpt["epoch"] 
= epochs + new_ckpt["global_step"] = steps + + if is_safetensors(output_file): + # TODO Tensor以外のdictの値を削除したほうがいいか + save_file(state_dict, output_file) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, + use_safetensors=False): + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 + if v2: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + else: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 + + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) + + +VAE_PREFIX = "first_stage_model." + + +def load_vae(vae_id, dtype): + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") + vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. 
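    # vae_sd carries the "first_stage_model." prefix here (it was added above when a
    # VAE-only file was loaded), so convert_ldm_vae_checkpoint can strip it just as
    # it does for a full checkpoint.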
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +# endregion + + +def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) + + resos = set() + + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) + + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) + + # # make additional resos + # if width >= height and width - divisible >= min_size: + # resos.add((width - divisible, height)) + # resos.add((height, width - divisible)) + # if height >= width and height - divisible >= min_size: + # resos.add((width, height - divisible)) + # resos.add((height - divisible, width)) + + size += divisible + + resos = list(resos) + resos.sort() + return resos + + +if __name__ == "__main__": + resos = make_bucket_resolutions((512, 768)) + print(len(resos)) + print(resos) + aspect_ratios = [w / h for w, h in resos] + print(aspect_ratios) + + ars = set() + for ar in aspect_ratios: + if ar in ars: + print("error! duplicate ar:", ar) + ars.add(ar) diff --git a/toolkit/layers.py b/toolkit/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc975bfb76ee564021c7ea823a3cdba09aeba48 --- /dev/null +++ b/toolkit/layers.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.utils.checkpoint import checkpoint + + +class ReductionKernel(nn.Module): + # Tensorflow + def __init__(self, in_channels, kernel_size=2, dtype=torch.float32, device=None): + if device is None: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + super(ReductionKernel, self).__init__() + self.kernel_size = kernel_size + self.in_channels = in_channels + numpy_kernel = self.build_kernel() + self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + + def build_kernel(self): + # tensorflow kernel is (height, width, in_channels, out_channels) + # pytorch kernel is (out_channels, in_channels, height, width) + kernel_size = self.kernel_size + channels = self.in_channels + kernel_shape = [channels, channels, kernel_size, kernel_size] + kernel = np.zeros(kernel_shape, np.float32) + + kernel_value = 1.0 / (kernel_size * kernel_size) + for i in range(0, channels): + kernel[i, i, :, :] = kernel_value + return kernel + + def forward(self, x): + return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1) + + +class CheckpointGradients(nn.Module): + def __init__(self, is_gradient_checkpointing=True): + super(CheckpointGradients, self).__init__() + self.is_gradient_checkpointing = is_gradient_checkpointing + + def forward(self, module, *args, num_chunks=1): + if self.is_gradient_checkpointing: + return checkpoint(module, *args, num_chunks=self.num_chunks) + else: + return module(*args) diff --git a/toolkit/llvae.py b/toolkit/llvae.py new file mode 100644 index 0000000000000000000000000000000000000000..9d559bfea01676ad9a9c255d930d693099b1a2c9 --- /dev/null +++ b/toolkit/llvae.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import numpy as np +import itertools + + +class LosslessLatentDecoder(nn.Module): + def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False): + super(LosslessLatentDecoder, 
self).__init__() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.latent_depth = latent_depth + self.in_channels = in_channels + self.out_channels = int(in_channels // (latent_depth * latent_depth)) + numpy_kernel = self.build_kernel(in_channels, latent_depth) + numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + if trainable: + self.kernel = nn.Parameter(numpy_kernel) + else: + self.kernel = numpy_kernel + + def build_kernel(self, in_channels, latent_depth): + # my old code from tensorflow. + # tensorflow kernel is (height, width, out_channels, in_channels) + # pytorch kernel is (in_channels, out_channels, height, width) + out_channels = self.out_channels + + # kernel_shape = [kernel_filter_size, kernel_filter_size, out_channels, in_channels] # tensorflow + kernel_shape = [in_channels, out_channels, latent_depth, latent_depth] # pytorch + kernel = np.zeros(kernel_shape, np.float32) + + # Build the kernel so that a 4 pixel cluster has each pixel come from a separate channel. + for c in range(0, out_channels): + i = 0 + for x, y in itertools.product(range(latent_depth), repeat=2): + # kernel[y, x, c, c * latent_depth * latent_depth + i] = 1 # tensorflow + kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0 # pytorch + i += 1 + + return kernel + + def forward(self, x): + dtype = x.dtype + if self.kernel.dtype != dtype: + self.kernel = self.kernel.to(dtype=dtype) + + # Deconvolve input tensor with the kernel + return nn.functional.conv_transpose2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1) + + +class LosslessLatentEncoder(nn.Module): + def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False): + super(LosslessLatentEncoder, self).__init__() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.latent_depth = latent_depth + self.in_channels = in_channels + self.out_channels = int(in_channels * (latent_depth * latent_depth)) + numpy_kernel = self.build_kernel(in_channels, latent_depth) + numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + if trainable: + self.kernel = nn.Parameter(numpy_kernel) + else: + self.kernel = numpy_kernel + + + def build_kernel(self, in_channels, latent_depth): + # my old code from tensorflow. + # tensorflow kernel is (height, width, in_channels, out_channels) + # pytorch kernel is (out_channels, in_channels, height, width) + out_channels = self.out_channels + + # kernel_shape = [latent_depth, latent_depth, in_channels, out_channels] # tensorflow + kernel_shape = [out_channels, in_channels, latent_depth, latent_depth] # pytorch + kernel = np.zeros(kernel_shape, np.float32) + + # Build the kernel so that a 4 pixel cluster has each pixel come from a separate channel. 
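        # For example, with in_channels=3 and latent_depth=2 this yields a fixed
        # (12, 3, 2, 2) space-to-depth kernel: conv2d with stride 2 turns a
        # (1, 3, 512, 512) image into a (1, 12, 256, 256) latent, and the matching
        # decoder kernel inverts the operation exactly (hence "lossless").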
+ for c in range(0, in_channels): + i = 0 + for x, y in itertools.product(range(latent_depth), repeat=2): + # kernel[y, x, c, c * latent_depth * latent_depth + i] = 1 # tensorflow + kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0 # pytorch + i += 1 + return kernel + + def forward(self, x): + dtype = x.dtype + if self.kernel.dtype != dtype: + self.kernel = self.kernel.to(dtype=dtype) + # Convolve input tensor with the kernel + return nn.functional.conv2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1) + + +class LosslessLatentVAE(nn.Module): + def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False): + super(LosslessLatentVAE, self).__init__() + self.latent_depth = latent_depth + self.in_channels = in_channels + self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable) + encoder_out_channels = self.encoder.out_channels + self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable) + + def forward(self, x): + latent = self.latent_encoder(x) + out = self.latent_decoder(latent) + return out + + def encode(self, x): + return self.encoder(x) + + def decode(self, x): + return self.decoder(x) + + +# test it +if __name__ == '__main__': + import os + from PIL import Image + import torchvision.transforms as transforms + user_path = os.path.expanduser('~') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + + input_path = os.path.join(user_path, "Pictures/sample_2_512.png") + output_path = os.path.join(user_path, "Pictures/sample_2_512_llvae.png") + img = Image.open(input_path) + img_tensor = transforms.ToTensor()(img) + img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype) + print("input_shape: ", list(img_tensor.shape)) + vae = LosslessLatentVAE(in_channels=3, latent_depth=8, dtype=dtype).to(device=device, dtype=dtype) + latent = vae.encode(img_tensor) + print("latent_shape: ", list(latent.shape)) + out_tensor = vae.decode(latent) + print("out_shape: ", list(out_tensor.shape)) + + mse_loss = nn.MSELoss() + mse = mse_loss(img_tensor, out_tensor) + print("roundtrip_loss: ", mse.item()) + out_img = transforms.ToPILImage()(out_tensor.squeeze(0)) + out_img.save(output_path) diff --git a/toolkit/logging.py b/toolkit/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..56b1c8b52a301c50fae9d016e74f9aa65c974882 --- /dev/null +++ b/toolkit/logging.py @@ -0,0 +1,84 @@ +from typing import OrderedDict, Optional +from PIL import Image + +from toolkit.config_modules import LoggingConfig + +# Base logger class +# This class does nothing, it's just a placeholder +class EmptyLogger: + def __init__(self, *args, **kwargs) -> None: + pass + + # start logging the training + def start(self): + pass + + # collect the log to send + def log(self, *args, **kwargs): + pass + + # send the log + def commit(self, step: Optional[int] = None): + pass + + # log image + def log_image(self, *args, **kwargs): + pass + + # finish logging + def finish(self): + pass + +# Wandb logger class +# This class logs the data to wandb +class WandbLogger(EmptyLogger): + def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None: + self.project = project + self.run_name = run_name + self.config = config + + def start(self): + try: + import wandb + except ImportError: + raise ImportError("Failed to import wandb. 
Please install wandb by running `pip install wandb`") + + # send the whole config to wandb + run = wandb.init(project=self.project, name=self.run_name, config=self.config) + self.run = run + self._log = wandb.log # log function + self._image = wandb.Image # image object + + def log(self, *args, **kwargs): + # when commit is False, wandb increments the step, + # but we don't want that to happen, so we set commit=False + self._log(*args, **kwargs, commit=False) + + def commit(self, step: Optional[int] = None): + # after overall one step is done, we commit the log + # by log empty object with commit=True + self._log({}, step=step, commit=True) + + def log_image( + self, + image: Image, + id, # sample index + caption: str | None = None, # positive prompt + *args, + **kwargs, + ): + # create a wandb image object and log it + image = self._image(image, caption=caption, *args, **kwargs) + self._log({f"sample_{id}": image}, commit=False) + + def finish(self): + self.run.finish() + +# create logger based on the logging config +def create_logger(logging_config: LoggingConfig, all_config: OrderedDict): + if logging_config.use_wandb: + project_name = logging_config.project_name + run_name = logging_config.run_name + return WandbLogger(project=project_name, run_name=run_name, config=all_config) + else: + return EmptyLogger() diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py new file mode 100644 index 0000000000000000000000000000000000000000..6c53439a5b6a0c47a74f32876c76ee92cc63bbaa --- /dev/null +++ b/toolkit/lora_special.py @@ -0,0 +1,505 @@ +import copy +import json +import math +import weakref +import os +import re +import sys +from typing import List, Optional, Dict, Type, Union +import torch +from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel +from transformers import CLIPTextModel + +from .config_modules import NetworkConfig +from .lorm import count_parameters +from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin +from .paths import SD_SCRIPTS_ROOT + +sys.path.append(SD_SCRIPTS_ROOT) + +from networks.lora import LoRANetwork, get_block_index +from toolkit.models.DoRA import DoRAModule + +from torch.utils.checkpoint import checkpoint + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +# diffusers specific stuff +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear', + 'QLinear', + # 'GroupNorm', +] +CONV_MODULES = [ + 'Conv2d', + 'LoRACompatibleConv', + 'QConv2d', +] + +class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + network: 'LoRASpecialNetwork' = None, + use_bias: bool = False, + **kwargs + ): + self.can_merge_in = True + """if alpha == 0 or None, alpha is rank (no scaling).""" + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) + self.lora_name = lora_name + self.orig_module_ref = weakref.ref(org_module) + self.scalar = torch.tensor(1.0) + # check if parent has bias. 
if not force use_bias to False + if org_module.bias is None: + use_bias = False + + if org_module.__class__.__name__ in CONV_MODULES: + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + + if org_module.__class__.__name__ in CONV_MODULES: + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier: Union[float, List[float]] = multiplier + # wrap the original module so it doesn't get weights updated + self.org_module = [org_module] + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.is_checkpointing = False + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + # del self.org_module + + +class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): + NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + + # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "ResnetBlock2D"] + UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"] + # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + PEFT_PREFIX_UNET = "unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + block_dims: Optional[List[int]] = None, + block_alphas: Optional[List[float]] = None, + conv_block_dims: Optional[List[int]] = None, + conv_block_alphas: Optional[List[float]] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + train_text_encoder: Optional[bool] = True, + use_text_encoder_1: bool = True, + use_text_encoder_2: bool = True, + train_unet: Optional[bool] = True, + is_sdxl=False, + is_v2=False, + is_v3=False, + is_pixart: bool = False, + is_auraflow: 
bool = False, + is_flux: bool = False, + use_bias: bool = False, + is_lorm: bool = False, + ignore_if_contains = None, + only_if_contains = None, + parameter_threshold: float = 0.0, + attn_only: bool = False, + target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE, + target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3, + network_type: str = "lora", + full_train_in_out: bool = False, + transformer_only: bool = False, + peft_format: bool = False, + is_assistant_adapter: bool = False, + **kwargs + ) -> None: + """ + LoRA network: すごく引数が多いが、パターンは以下の通り + 1. lora_dimとalphaを指定 + 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 + 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない + 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する + 5. modules_dimとmodules_alphaを指定 (推論用) + """ + # call the parent of the parent we are replacing (LoRANetwork) init + torch.nn.Module.__init__(self) + ToolkitNetworkMixin.__init__( + self, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + is_sdxl=is_sdxl, + is_v2=is_v2, + is_lorm=is_lorm, + **kwargs + ) + if ignore_if_contains is None: + ignore_if_contains = [] + self.ignore_if_contains = ignore_if_contains + self.transformer_only = transformer_only + + self.only_if_contains: Union[List, None] = only_if_contains + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.is_checkpointing = False + self._multiplier: float = 1.0 + self.is_active: bool = False + self.torch_multiplier = None + # triggers the state updates + self.multiplier = multiplier + self.is_sdxl = is_sdxl + self.is_v2 = is_v2 + self.is_v3 = is_v3 + self.is_pixart = is_pixart + self.is_auraflow = is_auraflow + self.is_flux = is_flux + self.network_type = network_type + self.is_assistant_adapter = is_assistant_adapter + if self.network_type.lower() == "dora": + self.module_class = DoRAModule + module_class = DoRAModule + + self.peft_format = peft_format + + # always do peft for flux only for now + if self.is_flux or self.is_v3: + self.peft_format = True + + if self.peft_format: + # no alpha for peft + self.alpha = self.lora_dim + alpha = self.alpha + self.conv_alpha = self.conv_lora_dim + conv_alpha = self.conv_alpha + + self.full_train_in_out = full_train_in_out + + if modules_dim is not None: + print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + print( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + if self.conv_lora_dim is not None: + print( + f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + unet_prefix = self.LORA_PREFIX_UNET + if self.peft_format: + unet_prefix = self.PEFT_PREFIX_UNET + if is_pixart or is_v3 or is_auraflow or is_flux: + unet_prefix = f"lora_transformer" + if self.peft_format: + unet_prefix = "transformer" + + prefix = ( + unet_prefix + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) + loras = [] + skipped = [] + attached_modules = [] + lora_shape_dict = {} + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ in LINEAR_MODULES + is_conv2d = child_module.__class__.__name__ in CONV_MODULES + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + + lora_name = [prefix, name, child_name] + # filter out blank + lora_name = [x for x in lora_name if x and x != ""] + lora_name = ".".join(lora_name) + # if it doesnt have a name, it wil have two dots + lora_name.replace("..", ".") + clean_name = lora_name + if self.peft_format: + # we replace this on saving + lora_name = lora_name.replace(".", "$$") + else: + lora_name = lora_name.replace(".", "_") + + skip = False + if any([word in clean_name for word in self.ignore_if_contains]): + skip = True + + # see if it is over threshold + if count_parameters(child_module) < parameter_threshold: + skip = True + + if self.transformer_only and self.is_pixart and is_unet: + if "transformer_blocks" not in lora_name: + skip = True + if self.transformer_only and self.is_flux and is_unet: + if "transformer_blocks" not in lora_name: + skip = True + if self.transformer_only and self.is_v3 and is_unet: + if "transformer_blocks" not in lora_name: + skip = True + + if (is_linear or is_conv2d) and not skip: + + if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]): + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif is_unet and block_dims is not None: + # U-Netでblock_dims指定あり + block_idx = get_block_index(lora_name) + if is_linear or is_conv2d_1x1: + dim = block_dims[block_idx] + alpha = block_alphas[block_idx] + elif conv_block_dims is not None: + dim = conv_block_dims[block_idx] + alpha = conv_block_alphas[block_idx] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or ( + self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + network=self, + parent=module, + use_bias=use_bias, + ) + loras.append(lora) + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape) + ] + return loras, skipped 
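        # Illustrative sketch (comment only; sizes and names are hypothetical, not taken from
        # this file): each module collected above adds a rank-r bypass around a frozen layer.
        # For a 16->32 Linear with lora_dim=4 and alpha=4 the computation is equivalent to:
        #
        #   x, multiplier = torch.randn(2, 16), 1.0
        #   base = torch.nn.Linear(16, 32)                        # frozen original module
        #   lora_down = torch.nn.Linear(16, 4, bias=False)
        #   lora_up = torch.nn.Linear(4, 32, bias=False)
        #   torch.nn.init.kaiming_uniform_(lora_down.weight, a=5 ** 0.5)
        #   torch.nn.init.zeros_(lora_up.weight)                  # adapter starts as a no-op
        #   scale = 4 / 4                                         # alpha / lora_dim
        #   y = base(x) + lora_up(lora_down(x)) * scale * multiplier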
+ + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + if train_text_encoder: + for i, text_encoder in enumerate(text_encoders): + if not use_text_encoder_1 and i == 0: + continue + if not use_text_encoder_2 and i == 1: + continue + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + + if self.is_pixart: + replace_modules = ["T5EncoderModel"] + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = target_lin_modules + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += target_conv_modules + + if is_v3: + target_modules = ["SD3Transformer2DModel"] + + if is_pixart: + target_modules = ["PixArtTransformer2DModel"] + + if is_auraflow: + target_modules = ["AuraFlowTransformer2DModel"] + + if is_flux: + target_modules = ["FluxTransformer2DModel"] + + if train_unet: + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + else: + self.unet_loras = [] + skipped_un = [] + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + print( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + print(f"\t{name}") + + self.up_lr_weight: List[float] = None + self.down_lr_weight: List[float] = None + self.mid_lr_weight: float = None + self.block_lr = False + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + if self.full_train_in_out: + print("full train in out") + # we are going to retrain the main in out layers for VAE change usually + if self.is_pixart: + transformer: PixArtTransformer2DModel = unet + self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed) + self.transformer_proj_out = copy.deepcopy(transformer.proj_out) + + transformer.pos_embed = self.transformer_pos_embed + transformer.proj_out = self.transformer_proj_out + + elif self.is_auraflow: + transformer: AuraFlowTransformer2DModel = unet + self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed) + self.transformer_proj_out = copy.deepcopy(transformer.proj_out) + + transformer.pos_embed = self.transformer_pos_embed + transformer.proj_out = self.transformer_proj_out + + else: + unet: UNet2DConditionModel = unet + unet_conv_in: torch.nn.Conv2d = unet.conv_in + unet_conv_out: torch.nn.Conv2d = unet.conv_out + + # clone these and replace their forwards with ours + self.unet_conv_in = copy.deepcopy(unet_conv_in) + self.unet_conv_out = copy.deepcopy(unet_conv_out) + unet.conv_in = self.unet_conv_in + unet.conv_out = self.unet_conv_out + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # call Lora prepare_optimizer_params + all_params = 
super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr) + + if self.full_train_in_out: + if self.is_pixart or self.is_auraflow or self.is_flux: + all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())}) + all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())}) + else: + all_params.append({"lr": unet_lr, "params": list(self.unet_conv_in.parameters())}) + all_params.append({"lr": unet_lr, "params": list(self.unet_conv_out.parameters())}) + + return all_params + + diff --git a/toolkit/lorm.py b/toolkit/lorm.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfdb516be12d98e464a7d9a96bc4e19b83b9f91 --- /dev/null +++ b/toolkit/lorm.py @@ -0,0 +1,461 @@ +from typing import Union, Tuple, Literal, Optional + +import torch +import torch.nn as nn +from diffusers import UNet2DConditionModel +from torch import Tensor +from tqdm import tqdm + +from toolkit.config_modules import LoRMConfig + +conv = nn.Conv2d +lin = nn.Linear +_size_2_t = Union[int, Tuple[int, int]] + +ExtractMode = Union[ + 'fixed', + 'threshold', + 'ratio', + 'quantile', + 'percentage' +] + +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear' +] +CONV_MODULES = [ + # 'Conv2d', + # 'LoRACompatibleConv' +] + +UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + # "ResnetBlock2D", + "Downsample2D", + "Upsample2D", +] + +LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE + +UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", +] + +UNET_MODULES_TO_AVOID = [ +] + + +# Low Rank Convolution +class LoRMCon2d(nn.Module): + def __init__( + self, + in_channels: int, + lorm_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 'same', + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + super().__init__() + self.in_channels = in_channels + self.lorm_channels = lorm_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.padding_mode = padding_mode + + self.down = nn.Conv2d( + in_channels=in_channels, + out_channels=lorm_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + padding_mode=padding_mode, + device=device, + dtype=dtype + ) + + # Kernel size on the up is always 1x1. 
+ # I don't think you could calculate a dual 3x3, or I can't at least + + self.up = nn.Conv2d( + in_channels=lorm_channels, + out_channels=out_channels, + kernel_size=(1, 1), + stride=1, + padding='same', + dilation=1, + groups=1, + bias=bias, + padding_mode='zeros', + device=device, + dtype=dtype + ) + + def forward(self, input: Tensor, *args, **kwargs) -> Tensor: + x = input + x = self.down(x) + x = self.up(x) + return x + + +class LoRMLinear(nn.Module): + def __init__( + self, + in_features: int, + lorm_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None + ) -> None: + super().__init__() + self.in_features = in_features + self.lorm_features = lorm_features + self.out_features = out_features + + self.down = nn.Linear( + in_features=in_features, + out_features=lorm_features, + bias=False, + device=device, + dtype=dtype + + ) + self.up = nn.Linear( + in_features=lorm_features, + out_features=out_features, + bias=bias, + # bias=True, + device=device, + dtype=dtype + ) + + def forward(self, input: Tensor, *args, **kwargs) -> Tensor: + x = input + x = self.down(x) + x = self.up(x) + return x + + +def extract_conv( + weight: Union[torch.Tensor, nn.Parameter], + mode='fixed', + mode_param=0, + device='cpu' +) -> Tuple[Tensor, Tensor, int, Tensor]: + weight = weight.to(device) + out_ch, in_ch, kernel_size, _ = weight.shape + + U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1)) + if mode == 'percentage': + assert 0 <= mode_param <= 1 # Ensure it's a valid percentage. + original_params = out_ch * in_ch * kernel_size * kernel_size + desired_params = mode_param * original_params + # Solve for lora_rank from the equation + lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch)) + elif mode == 'fixed': + lora_rank = mode_param + elif mode == 'threshold': + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param).item() + elif mode == 'ratio': + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s).item() + elif mode == 'quantile' or mode == 'percentile': + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum).item() + else: + raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2: + lora_rank = int(out_ch / 2) + print(f"rank is higher than it should be") + # print(f"Skipping layer as determined rank is too high") + # return None, None, None, None + # return weight, 'full' + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach() + extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach() + del U, S, Vh, weight + return extract_weight_A, extract_weight_B, lora_rank, diff + + +def extract_linear( + weight: Union[torch.Tensor, nn.Parameter], + mode='fixed', + mode_param=0, + device='cpu', +) -> Tuple[Tensor, Tensor, int, Tensor]: + weight = weight.to(device) + out_ch, in_ch = weight.shape + + U, S, Vh = torch.linalg.svd(weight) + + if mode == 'percentage': + assert 0 <= mode_param <= 1 # Ensure it's a valid percentage. 
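        # Worked example (hypothetical numbers): for a 1024x1024 Linear with mode_param=0.25,
        # the parameter budget is 0.25 * 1024 * 1024 = 262,144; a rank-r factorization costs
        # r * (in_ch + out_ch) = 2048 * r parameters, so the division below gives lora_rank = 128.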
+ desired_params = mode_param * out_ch * in_ch + # Solve for lora_rank from the equation + lora_rank = int(desired_params / (in_ch + out_ch)) + elif mode == 'fixed': + lora_rank = mode_param + elif mode == 'threshold': + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param).item() + elif mode == 'ratio': + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s).item() + elif mode == 'quantile': + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum).item() + else: + raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2: + # print(f"rank is higher than it should be") + lora_rank = int(out_ch / 2) + # return weight, 'full' + # print(f"Skipping layer as determined rank is too high") + # return None, None, None, None + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - U @ Vh).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch).detach() + extract_weight_B = U.reshape(out_ch, lora_rank).detach() + del U, S, Vh, weight + return extract_weight_A, extract_weight_B, lora_rank, diff + + +def replace_module_by_path(network, name, module): + """Replace a module in a network by its name.""" + name_parts = name.split('.') + current_module = network + for part in name_parts[:-1]: + current_module = getattr(current_module, part) + try: + setattr(current_module, name_parts[-1], module) + except Exception as e: + print(e) + + +def count_parameters(module): + return sum(p.numel() for p in module.parameters()) + + +def compute_optimal_bias(original_module, linear_down, linear_up, X): + Y_original = original_module(X) + Y_approx = linear_up(linear_down(X)) + E = Y_original - Y_approx + + optimal_bias = E.mean(dim=0) + + return optimal_bias + + +def format_with_commas(n): + return f"{n:,}" + + +def print_lorm_extract_details( + start_num_params: int, + end_num_params: int, + num_replaced: int, +): + start_formatted = format_with_commas(start_num_params) + end_formatted = format_with_commas(end_num_params) + num_replaced_formatted = format_with_commas(num_replaced) + + width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted)) + + print(f"Convert UNet result:") + print(f" - converted: {num_replaced:>{width},} modules") + print(f" - start: {start_num_params:>{width},} params") + print(f" - end: {end_num_params:>{width},} params") + + +lorm_ignore_if_contains = [ + 'proj_out', 'proj_in', +] + +lorm_parameter_threshold = 1000000 + + +@torch.no_grad() +def convert_diffusers_unet_to_lorm( + unet: UNet2DConditionModel, + config: LoRMConfig, +): + print('Converting UNet to LoRM UNet') + start_num_params = count_parameters(unet) + named_modules = list(unet.named_modules()) + + num_replaced = 0 + + pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet") + layer_names_replaced = [] + converted_modules = [] + ignore_if_contains = [ + 'proj_out', 'proj_in', + ] + + for name, module in named_modules: + module_name = module.__class__.__name__ + if module_name in UNET_TARGET_REPLACE_MODULE: + for child_name, child_module in module.named_modules(): + new_module: Union[LoRMCon2d, LoRMLinear, None] = None + # if child name includes attn, skip it + combined_name = combined_name = f"{name}.{child_name}" + # if child_module.__class__.__name__ in 
LINEAR_MODULES and child_module.bias is None: + # pass + + lorm_config = config.get_config_for_module(combined_name) + + extract_mode = lorm_config.extract_mode + extract_mode_param = lorm_config.extract_mode_param + parameter_threshold = lorm_config.parameter_threshold + + if any([word in child_name for word in ignore_if_contains]): + pass + + elif child_module.__class__.__name__ in LINEAR_MODULES: + if count_parameters(child_module) > parameter_threshold: + + # dtype = child_module.weight.dtype + dtype = torch.float32 + # extract and convert + down_weight, up_weight, lora_dim, diff = extract_linear( + weight=child_module.weight.clone().detach().float(), + mode=extract_mode, + mode_param=extract_mode_param, + device=child_module.weight.device, + ) + if down_weight is None: + continue + down_weight = down_weight.to(dtype=dtype) + up_weight = up_weight.to(dtype=dtype) + bias_weight = None + if child_module.bias is not None: + bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype) + # linear layer weights = (out_features, in_features) + new_module = LoRMLinear( + in_features=down_weight.shape[1], + lorm_features=lora_dim, + out_features=up_weight.shape[0], + bias=bias_weight is not None, + device=down_weight.device, + dtype=down_weight.dtype + ) + + # replace the weights + new_module.down.weight.data = down_weight + new_module.up.weight.data = up_weight + if bias_weight is not None: + new_module.up.bias.data = bias_weight + # else: + # new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data) + + # bias_correction = compute_optimal_bias( + # child_module, + # new_module.down, + # new_module.up, + # torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype) + # ) + # new_module.up.bias.data += bias_correction + + elif child_module.__class__.__name__ in CONV_MODULES: + if count_parameters(child_module) > parameter_threshold: + dtype = child_module.weight.dtype + down_weight, up_weight, lora_dim, diff = extract_conv( + weight=child_module.weight.clone().detach().float(), + mode=extract_mode, + mode_param=extract_mode_param, + device=child_module.weight.device, + ) + if down_weight is None: + continue + down_weight = down_weight.to(dtype=dtype) + up_weight = up_weight.to(dtype=dtype) + bias_weight = None + if child_module.bias is not None: + bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype) + + new_module = LoRMCon2d( + in_channels=down_weight.shape[1], + lorm_channels=lora_dim, + out_channels=up_weight.shape[0], + kernel_size=child_module.kernel_size, + dilation=child_module.dilation, + padding=child_module.padding, + padding_mode=child_module.padding_mode, + stride=child_module.stride, + bias=bias_weight is not None, + device=down_weight.device, + dtype=down_weight.dtype + ) + # replace the weights + new_module.down.weight.data = down_weight + new_module.up.weight.data = up_weight + if bias_weight is not None: + new_module.up.bias.data = bias_weight + + if new_module: + combined_name = f"{name}.{child_name}" + replace_module_by_path(unet, combined_name, new_module) + converted_modules.append(new_module) + num_replaced += 1 + layer_names_replaced.append( + f"{combined_name} - {format_with_commas(count_parameters(child_module))}") + + pbar.update(1) + pbar.close() + end_num_params = count_parameters(unet) + + def sorting_key(s): + # Extract the number part, remove commas, and convert to integer + return int(s.split("-")[1].strip().replace(",", "")) + + sorted_layer_names_replaced = sorted(layer_names_replaced, 
key=sorting_key, reverse=True) + for layer_name in sorted_layer_names_replaced: + print(layer_name) + + print_lorm_extract_details( + start_num_params=start_num_params, + end_num_params=end_num_params, + num_replaced=num_replaced, + ) + + return converted_modules diff --git a/toolkit/losses.py b/toolkit/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..eeea357111f38b54f6b79ea3e73f23f43ba2dbd7 --- /dev/null +++ b/toolkit/losses.py @@ -0,0 +1,97 @@ +import torch +from .llvae import LosslessLatentEncoder + + +def total_variation(image): + """ + Compute normalized total variation. + Inputs: + - image: PyTorch Variable of shape (N, C, H, W) + Returns: + - TV: total variation normalized by the number of elements + """ + n_elements = image.shape[1] * image.shape[2] * image.shape[3] + return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + + torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements) + + +class ComparativeTotalVariation(torch.nn.Module): + """ + Compute the comparative loss in tv between two images. to match their tv + """ + + def forward(self, pred, target): + return torch.abs(total_variation(pred) - total_variation(target)) + + +# Gradient penalty +def get_gradient_penalty(critic, real, fake, device): + with torch.autocast(device_type='cuda'): + real = real.float() + fake = fake.float() + alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float() + interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True) + if torch.isnan(interpolates).any(): + print('d_interpolates is nan') + d_interpolates = critic(interpolates) + fake = torch.ones(real.size(0), 1, device=device) + + if torch.isnan(d_interpolates).any(): + print('fake is nan') + gradients = torch.autograd.grad( + outputs=d_interpolates, + inputs=interpolates, + grad_outputs=fake, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + # see if any are nan + if torch.isnan(gradients).any(): + print('gradients is nan') + + gradients = gradients.view(gradients.size(0), -1) + gradient_norm = gradients.norm(2, dim=1) + gradient_penalty = ((gradient_norm - 1) ** 2).mean() + return gradient_penalty.float() + + +class PatternLoss(torch.nn.Module): + def __init__(self, pattern_size=4, dtype=torch.float32): + super().__init__() + self.pattern_size = pattern_size + self.llvae_encoder = LosslessLatentEncoder(3, pattern_size, dtype=dtype) + + def forward(self, pred, target): + pred_latents = self.llvae_encoder(pred) + target_latents = self.llvae_encoder(target) + + matrix_pixels = self.pattern_size * self.pattern_size + + color_chans = pred_latents.shape[1] // 3 + # pytorch + r_chans, g_chans, b_chans = torch.split(pred_latents, [color_chans, color_chans, color_chans], 1) + r_chans_target, g_chans_target, b_chans_target = torch.split(target_latents, [color_chans, color_chans, color_chans], 1) + + def separated_chan_loss(latent_chan): + nonlocal matrix_pixels + chan_mean = torch.mean(latent_chan, dim=[1, 2, 3]) + chan_splits = torch.split(latent_chan, [1 for i in range(matrix_pixels)], 1) + chan_loss = None + for chan in chan_splits: + this_mean = torch.mean(chan, dim=[1, 2, 3]) + this_chan_loss = torch.abs(this_mean - chan_mean) + if chan_loss is None: + chan_loss = this_chan_loss + else: + chan_loss = chan_loss + this_chan_loss + chan_loss = chan_loss * (1 / matrix_pixels) + return chan_loss + + r_chan_loss = torch.abs(separated_chan_loss(r_chans) - separated_chan_loss(r_chans_target)) + g_chan_loss = torch.abs(separated_chan_loss(g_chans) 
- separated_chan_loss(g_chans_target)) + b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target)) + return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333 + + diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py new file mode 100644 index 0000000000000000000000000000000000000000..84021b49cdd3853972721924c1f957203e17e49d --- /dev/null +++ b/toolkit/lycoris_special.py @@ -0,0 +1,373 @@ +import math +import os +from typing import Optional, Union, List, Type + +import torch +from lycoris.kohya import LycorisNetwork, LoConModule +from lycoris.modules.glora import GLoRAModule +from torch import nn +from transformers import CLIPTextModel +from torch.nn import functional as F +from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin + +# diffusers specific stuff +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear' +] +CONV_MODULES = [ + 'Conv2d', + 'LoRACompatibleConv' +] + +class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin): + def __init__( + self, + lora_name, org_module: nn.Module, + multiplier=1.0, + lora_dim=4, alpha=1, + dropout=0., rank_dropout=0., module_dropout=0., + use_cp=False, + network: 'LycorisSpecialNetwork' = None, + use_bias=False, + **kwargs, + ): + """ if alpha == 0 or None, alpha is rank (no scaling). """ + # call super of super + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) + self.lora_name = lora_name + self.lora_dim = lora_dim + self.cp = False + + # check if parent has bias. if not force use_bias to False + if org_module.bias is None: + use_bias = False + + self.scalar = nn.Parameter(torch.tensor(0.0)) + orig_module_name = org_module.__class__.__name__ + if orig_module_name in CONV_MODULES: + self.isconv = True + # For general LoCon + in_dim = org_module.in_channels + k_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + out_dim = org_module.out_channels + self.down_op = F.conv2d + self.up_op = F.conv2d + if use_cp and k_size != (1, 1): + self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False) + self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False) + self.cp = True + else: + self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False) + self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias) + elif orig_module_name in LINEAR_MODULES: + self.isconv = False + self.down_op = F.linear + self.up_op = F.linear + if orig_module_name == 'GroupNorm': + # RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32) + in_dim = org_module.num_channels + out_dim = org_module.num_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias) + else: + raise NotImplementedError + self.shape = org_module.weight.shape + + if dropout: + self.dropout = nn.Dropout(dropout) + else: + self.dropout = nn.Identity() + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + 
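        # Unlike the standard LoRA modules in this toolkit (which zero-initialize lora_up),
        # lora_up here receives a Kaiming init on the next line; the zero-initialized
        # self.scalar parameter defined above appears to be what keeps this adapter from
        # perturbing the base module at the start of training.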
torch.nn.init.kaiming_uniform_(self.lora_up.weight) + if self.cp: + torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5)) + + self.multiplier = multiplier + self.org_module = [org_module] + self.register_load_state_dict_post_hook(self.load_weight_hook) + + def load_weight_hook(self, *args, **kwargs): + self.scalar = nn.Parameter(torch.ones_like(self.scalar)) + + +class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D", + # 'UNet2DConditionModel', + # 'Conv2d', + # 'Timesteps', + # 'TimestepEmbedding', + # 'Linear', + # 'SiLU', + # 'ModuleList', + # 'DownBlock2D', + # 'ResnetBlock2D', # need + # 'GroupNorm', + # 'LoRACompatibleConv', + # 'LoRACompatibleLinear', + # 'Dropout', + # 'CrossAttnDownBlock2D', # needed + # 'Transformer2DModel', # maybe not, has duplicates + # 'BasicTransformerBlock', # duplicates + # 'LayerNorm', + # 'Attention', + # 'FeedForward', + # 'GEGLU', + # 'UpBlock2D', + # 'UNetMidBlock2DCrossAttn' + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + use_cp: Optional[bool] = False, + network_module: Type[object] = LoConSpecialModule, + train_unet: bool = True, + train_text_encoder: bool = True, + use_text_encoder_1: bool = True, + use_text_encoder_2: bool = True, + use_bias: bool = False, + is_lorm: bool = False, + **kwargs, + ) -> None: + # call ToolkitNetworkMixin super + ToolkitNetworkMixin.__init__( + self, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + is_lorm=is_lorm, + **kwargs + ) + # call the parent of the parent LycorisNetwork + torch.nn.Module.__init__(self) + + # LyCORIS unique stuff + if dropout is None: + dropout = 0 + if rank_dropout is None: + rank_dropout = 0 + if module_dropout is None: + module_dropout = 0 + self.train_unet = train_unet + self.train_text_encoder = train_text_encoder + + self.torch_multiplier = None + # triggers a tensor update + self.multiplier = multiplier + self.lora_dim = lora_dim + + if not self.ENABLE_CONV or conv_lora_dim is None: + conv_lora_dim = 0 + conv_alpha = 0 + + self.conv_lora_dim = int(conv_lora_dim) + if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim: + print('Apply different lora dim for conv layer') + print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}') + elif self.conv_lora_dim == 0: + print('Disable conv layer') + + self.alpha = alpha + self.conv_alpha = float(conv_alpha) + if self.conv_lora_dim and self.alpha != self.conv_alpha: + print('Apply different alpha value for conv layer') + print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}') + + if 1 >= dropout >= 0: + print(f'Use Dropout value: {dropout}') + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + # create module instances + def create_modules( + prefix, + root_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[] + ) -> List[network_module]: + print('Create LyCORIS Module') + loras = [] + # remove this + named_modules = root_module.named_modules() + # add a few to tthe generator + + for name, 
module in named_modules: + module_name = module.__class__.__name__ + if module_name in target_replace_modules: + if module_name in self.MODULE_ALGO_MAP: + algo = self.MODULE_ALGO_MAP[module_name] + else: + algo = network_module + for child_name, child_module in module.named_modules(): + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + if lora_name.startswith('lora_unet_input_blocks_1_0_emb_layers_1'): + print(f"{lora_name}") + + if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + elif child_module.__class__.__name__ in CONV_MODULES: + k_size, *_ = child_module.kernel_size + if k_size == 1 and lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + elif conv_lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.conv_lora_dim, self.conv_alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + else: + continue + else: + continue + loras.append(lora) + elif name in target_replace_names: + if name in self.NAME_ALGO_MAP: + algo = self.NAME_ALGO_MAP[name] + else: + algo = network_module + lora_name = prefix + '.' + name + lora_name = lora_name.replace('.', '_') + if module.__class__.__name__ == 'Linear' and lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + parent=module, + network=self, + use_bias=use_bias, + **kwargs + ) + elif module.__class__.__name__ == 'Conv2d': + k_size, *_ = module.kernel_size + if k_size == 1 and lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + elif conv_lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.conv_lora_dim, self.conv_alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + else: + continue + else: + continue + loras.append(lora) + return loras + + if network_module == GLoRAModule: + print('GLoRA enabled, only train transformer') + # only train transformer (for GLoRA) + LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + ] + LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = [] + + if isinstance(text_encoder, list): + text_encoders = text_encoder + use_index = True + else: + text_encoders = [text_encoder] + use_index = False + + self.text_encoder_loras = [] + if self.train_text_encoder: + for i, te in enumerate(text_encoders): + if not use_text_encoder_1 and i == 0: + continue + if not use_text_encoder_2 and i == 1: + continue + self.text_encoder_loras.extend(create_modules( + LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''), + te, + LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + )) + print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.") + if self.train_unet: + 
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet, + LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE) + else: + self.unet_loras = [] + print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.") + + self.weights_sd = None + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) diff --git a/toolkit/lycoris_utils.py b/toolkit/lycoris_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af11ee9ef52c9c0a42ac34afc3e3fa42c7c4b83d --- /dev/null +++ b/toolkit/lycoris_utils.py @@ -0,0 +1,536 @@ +# heavily based on https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/utils.py + +from typing import * + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torch.linalg as linalg + +from tqdm import tqdm +from collections import OrderedDict + + +def make_sparse(t: torch.Tensor, sparsity=0.95): + abs_t = torch.abs(t) + np_array = abs_t.detach().cpu().numpy() + quan = float(np.quantile(np_array, sparsity)) + sparse_t = t.masked_fill(abs_t < quan, 0) + return sparse_t + + +def extract_conv( + weight: Union[torch.Tensor, nn.Parameter], + mode='fixed', + mode_param=0, + device='cpu', + is_cp=False, +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch, kernel_size, _ = weight.shape + + U, S, Vh = linalg.svd(weight.reshape(out_ch, -1)) + + if mode == 'fixed': + lora_rank = mode_param + elif mode == 'threshold': + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param) + elif mode == 'ratio': + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s) + elif mode == 'quantile' or mode == 'percentile': + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum) + else: + raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2 and not is_cp: + return weight, 'full' + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach() + extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), 'low rank' + + +def extract_linear( + weight: Union[torch.Tensor, nn.Parameter], + mode='fixed', + mode_param=0, + device='cpu', +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch = weight.shape + + U, S, Vh = linalg.svd(weight) + + if mode == 'fixed': + lora_rank = mode_param + elif mode == 'threshold': + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param) + elif mode == 'ratio': + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s) + elif mode == 'quantile' or mode == 'percentile': + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum) + else: + raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, 
in_ch, lora_rank) + if lora_rank >= out_ch / 2: + return weight, 'full' + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - U @ Vh).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch).detach() + extract_weight_B = U.reshape(out_ch, lora_rank).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), 'low rank' + + +def extract_diff( + base_model, + db_model, + mode='fixed', + linear_mode_param=0, + conv_mode_param=0, + extract_device='cpu', + use_bias=False, + sparsity=0.98, + small_conv=True, + linear_only=False, + extract_unet=True, + extract_text_encoder=True, +): + meta = OrderedDict() + + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D" + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + if linear_only: + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + ] + + if not extract_unet: + UNET_TARGET_REPLACE_MODULE = [] + UNET_TARGET_REPLACE_NAME = [] + + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + + if not extract_text_encoder: + TEXT_ENCODER_TARGET_REPLACE_MODULE = [] + + LORA_PREFIX_UNET = 'lora_unet' + LORA_PREFIX_TEXT_ENCODER = 'lora_te' + + def make_state_dict( + prefix, + root_module: torch.nn.Module, + target_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[] + ): + loras = {} + temp = {} + temp_name = {} + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + temp[name] = {} + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}: + continue + temp[name][child_name] = child_module.weight + elif name in target_replace_names: + temp_name[name] = module.weight + + for name, module in tqdm(list(target_module.named_modules())): + if name in temp: + weights = temp[name] + for child_name, child_module in module.named_modules(): + lora_name = prefix + '.' + name + '.' 
+ child_name + lora_name = lora_name.replace('.', '_') + layer = child_module.__class__.__name__ + if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}: + root_weight = child_module.weight + if torch.allclose(root_weight, weights[child_name]): + continue + + if layer == 'Linear' or layer == 'LoRACompatibleLinear': + weight, decompose_mode = extract_linear( + (child_module.weight - weights[child_name]), + mode, + linear_mode_param, + device=extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + elif layer == 'Conv2d' or layer == 'LoRACompatibleConv': + is_linear = (child_module.weight.shape[2] == 1 + and child_module.weight.shape[3] == 1) + if not is_linear and linear_only: + continue + weight, decompose_mode = extract_conv( + (child_module.weight - weights[child_name]), + mode, + linear_mode_param if is_linear else conv_mode_param, + device=extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == 'low rank': + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + 'fixed', dim, + extract_device, True + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half() + diff = child_module.weight - torch.einsum( + 'i j k l, j r, p i -> p r k l', + extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1) + ).detach().cpu().contiguous() + del extract_c + else: + continue + if decompose_mode == 'low rank': + loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half() + loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half() + loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f'{lora_name}.bias_indices'] = indices + loras[f'{lora_name}.bias_values'] = values + loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16) + del extract_a, extract_b, diff + elif decompose_mode == 'full': + loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half() + else: + raise NotImplementedError + elif name in temp_name: + weights = temp_name[name] + lora_name = prefix + '.' 
+ name + lora_name = lora_name.replace('.', '_') + layer = module.__class__.__name__ + + if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}: + root_weight = module.weight + if torch.allclose(root_weight, weights): + continue + + if layer == 'Linear' or layer == 'LoRACompatibleLinear': + weight, decompose_mode = extract_linear( + (root_weight - weights), + mode, + linear_mode_param, + device=extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + elif layer == 'Conv2d' or layer == 'LoRACompatibleConv': + is_linear = ( + root_weight.shape[2] == 1 + and root_weight.shape[3] == 1 + ) + if not is_linear and linear_only: + continue + weight, decompose_mode = extract_conv( + (root_weight - weights), + mode, + linear_mode_param if is_linear else conv_mode_param, + device=extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == 'low rank': + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + 'fixed', dim, + extract_device, True + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half() + diff = root_weight - torch.einsum( + 'i j k l, j r, p i -> p r k l', + extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1) + ).detach().cpu().contiguous() + del extract_c + else: + continue + if decompose_mode == 'low rank': + loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half() + loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half() + loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f'{lora_name}.bias_indices'] = indices + loras[f'{lora_name}.bias_values'] = values + loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16) + del extract_a, extract_b, diff + elif decompose_mode == 'full': + loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half() + else: + raise NotImplementedError + return loras + + text_encoder_loras = make_state_dict( + LORA_PREFIX_TEXT_ENCODER, + base_model[0], db_model[0], + TEXT_ENCODER_TARGET_REPLACE_MODULE + ) + + unet_loras = make_state_dict( + LORA_PREFIX_UNET, + base_model[2], db_model[2], + UNET_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + print(len(text_encoder_loras), len(unet_loras)) + # the | will + return (text_encoder_loras | unet_loras), meta + + +def get_module( + lyco_state_dict: Dict, + lora_name +): + if f'{lora_name}.lora_up.weight' in lyco_state_dict: + up = lyco_state_dict[f'{lora_name}.lora_up.weight'] + down = lyco_state_dict[f'{lora_name}.lora_down.weight'] + mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None) + alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'locon', (up, down, mid, alpha) + elif f'{lora_name}.hada_w1_a' in lyco_state_dict: + w1a = lyco_state_dict[f'{lora_name}.hada_w1_a'] + w1b = lyco_state_dict[f'{lora_name}.hada_w1_b'] + w2a = lyco_state_dict[f'{lora_name}.hada_w2_a'] + w2b = lyco_state_dict[f'{lora_name}.hada_w2_b'] + t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None) + t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None) + 
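        # LoHa ('hada') entries hold two low-rank factor pairs; rebuild_weight() below combines
        # them as (w1a @ w1b) * (w2a @ w2b), an element-wise (Hadamard) product, with the
        # optional CP tensors t1/t2 used for conv kernels.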
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha) + elif f'{lora_name}.weight' in lyco_state_dict: + weight = lyco_state_dict[f'{lora_name}.weight'] + on_input = lyco_state_dict.get(f'{lora_name}.on_input', False) + return 'ia3', (weight, on_input) + elif (f'{lora_name}.lokr_w1' in lyco_state_dict + or f'{lora_name}.lokr_w1_a' in lyco_state_dict): + w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None) + w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None) + w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None) + w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None) + w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None) + w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None) + t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None) + t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None) + alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha) + elif f'{lora_name}.diff' in lyco_state_dict: + return 'full', lyco_state_dict[f'{lora_name}.diff'] + else: + return 'None', () + + +def cp_weight_from_conv( + up, down, mid +): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down) + + +def cp_weight( + wa, wb, t +): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +@torch.no_grad() +def rebuild_weight(module_type, params, orig_weight, scale=1): + if orig_weight is None: + return orig_weight + merged = orig_weight + if module_type == 'locon': + up, down, mid, alpha = params + if alpha is not None: + scale *= alpha / up.size(1) + if mid is not None: + rebuild = cp_weight_from_conv(up, down, mid) + else: + rebuild = up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1) + merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale + del up, down, mid, alpha, params, rebuild + elif module_type == 'hada': + w1a, w1b, w2a, w2b, t1, t2, alpha = params + if alpha is not None: + scale *= alpha / w1b.size(0) + if t1 is not None: + rebuild1 = cp_weight(w1a, w1b, t1) + else: + rebuild1 = w1a @ w1b + if t2 is not None: + rebuild2 = cp_weight(w2a, w2b, t2) + else: + rebuild2 = w2a @ w2b + rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape) + merged = orig_weight + rebuild * scale + del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2 + elif module_type == 'ia3': + weight, on_input = params + if not on_input: + weight = weight.reshape(-1, 1) + merged = orig_weight + weight * orig_weight * scale + del weight, on_input, params + elif module_type == 'kron': + w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params + if alpha is not None and (w1b is not None or w2b is not None): + scale *= alpha / (w1b.size(0) if w1b else w2b.size(0)) + if w1a is not None and w1b is not None: + if t1: + w1 = cp_weight(w1a, w1b, t1) + else: + w1 = w1a @ w1b + if w2a is not None and w2b is not None: + if t2: + w2 = cp_weight(w2a, w2b, t2) + else: + w2 = w2a @ w2b + rebuild = torch.kron(w1, w2).reshape(orig_weight.shape) + merged = orig_weight + rebuild * scale + del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild + elif module_type == 'full': + rebuild = params.reshape(orig_weight.shape) + merged = orig_weight + rebuild * scale + del params, rebuild + + return merged + + +def merge( + base_model, + lyco_state_dict, + scale: float = 1.0, + device='cpu' +): + 
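    # merge() folds a LyCORIS/LoRA state dict into the base model in place: for each matched
    # module it rebuilds the delta with rebuild_weight() and copies orig_weight + delta * scale
    # into the module's weight. Hypothetical usage sketch (tuple layout inferred from the
    # indexing below, where base_model[0] is the text encoder, base_model[2] is the UNet, and
    # the middle element is not used by this function):
    #
    #   from safetensors.torch import load_file
    #   merge((text_encoder, _, unet), load_file("extracted_lora.safetensors"), scale=1.0)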
UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D" + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = 'lora_unet' + LORA_PREFIX_TEXT_ENCODER = 'lora_te' + merged = 0 + + def merge_state_dict( + prefix, + root_module: torch.nn.Module, + lyco_state_dict: Dict[str, torch.Tensor], + target_replace_modules, + target_replace_names=[] + ): + nonlocal merged + for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d', + 'LoRACompatibleConv'}: + continue + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + + result = rebuild_weight(*get_module( + lyco_state_dict, lora_name + ), getattr(child_module, 'weight'), scale) + if result is not None: + merged += 1 + child_module.requires_grad_(False) + child_module.weight.copy_(result) + elif name in target_replace_names: + lora_name = prefix + '.' + name + lora_name = lora_name.replace('.', '_') + + result = rebuild_weight(*get_module( + lyco_state_dict, lora_name + ), getattr(module, 'weight'), scale) + if result is not None: + merged += 1 + module.requires_grad_(False) + module.weight.copy_(result) + + if device == 'cpu': + for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'): + lyco_state_dict[k] = v.float() + + merge_state_dict( + LORA_PREFIX_TEXT_ENCODER, + base_model[0], + lyco_state_dict, + TEXT_ENCODER_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + merge_state_dict( + LORA_PREFIX_UNET, + base_model[2], + lyco_state_dict, + UNET_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + print(f'{merged} Modules been merged') diff --git a/toolkit/metadata.py b/toolkit/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5c36adae70feb2624a84a3b8dbe05f24ed60ed --- /dev/null +++ b/toolkit/metadata.py @@ -0,0 +1,88 @@ +import json +from collections import OrderedDict +from io import BytesIO + +import safetensors +from safetensors import safe_open + +from info import software_meta +from toolkit.train_tools import addnet_hash_legacy +from toolkit.train_tools import addnet_hash_safetensors + + +def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True) -> OrderedDict: + # stringify the meta and reparse OrderedDict to replace [name] with name + meta_string = json.dumps(meta) + if name is not None: + meta_string = meta_string.replace("[name]", name) + save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict) + if add_software_info: + save_meta["software"] = software_meta + # safetensors can only be one level deep + for key, value in save_meta.items(): + # if not float, int, bool, or str, convert to json string + if not isinstance(value, str): + save_meta[key] = json.dumps(value) + # add the pt format + save_meta["format"] = "pt" + return save_meta + + +def add_model_hash_to_meta(state_dict, meta: OrderedDict) -> OrderedDict: + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training 
metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in meta.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(state_dict, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + meta["sshs_model_hash"] = model_hash + meta["sshs_legacy_hash"] = legacy_hash + return meta + + +def add_base_model_info_to_meta( + meta: OrderedDict, + base_model: str = None, + is_v1: bool = False, + is_v2: bool = False, + is_xl: bool = False, +) -> OrderedDict: + if base_model is not None: + meta['ss_base_model'] = base_model + elif is_v2: + meta['ss_v2'] = True + meta['ss_base_model_version'] = 'sd_2.1' + + elif is_xl: + meta['ss_base_model_version'] = 'sdxl_1.0' + else: + # default to v1.5 + meta['ss_base_model_version'] = 'sd_1.5' + return meta + + +def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict: + parsed_meta = OrderedDict() + for key, value in meta.items(): + try: + parsed_meta[key] = json.loads(value) + except json.decoder.JSONDecodeError: + parsed_meta[key] = value + return parsed_meta + + +def load_metadata_from_safetensors(file_path: str) -> OrderedDict: + try: + with safe_open(file_path, framework="pt") as f: + metadata = f.metadata() + return parse_metadata_from_safetensors(metadata) + except Exception as e: + print(f"Error loading metadata from {file_path}: {e}") + return OrderedDict() diff --git a/toolkit/models/DoRA.py b/toolkit/models/DoRA.py new file mode 100644 index 0000000000000000000000000000000000000000..653575e94e640ae2900230d1e3b36f8d3ea5f93e --- /dev/null +++ b/toolkit/models/DoRA.py @@ -0,0 +1,146 @@ +#based off https://github.com/catid/dora/blob/main/dora.py +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import TYPE_CHECKING, Union, List + +from optimum.quanto import QBytesTensor, QTensor + +from toolkit.network_mixins import ToolkitModuleMixin, ExtractableModuleMixin + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# diffusers specific stuff +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear' + # 'GroupNorm', +] +CONV_MODULES = [ + 'Conv2d', + 'LoRACompatibleConv' +] + +def transpose(weight, fan_in_fan_out): + if not fan_in_fan_out: + return weight + + if isinstance(weight, torch.nn.Parameter): + return torch.nn.Parameter(weight.T) + return weight.T + +class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): + # def __init__(self, d_in, d_out, rank=4, weight=None, bias=None): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + network: 'LoRASpecialNetwork' = None, + use_bias: bool = False, + **kwargs + ): + self.can_merge_in = False + """if alpha == 0 or None, alpha is rank (no scaling).""" + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) + self.lora_name = lora_name + self.scalar = torch.tensor(1.0) + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ in CONV_MODULES: + raise NotImplementedError("Convolutional layers are not supported yet") + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + # self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える eng: treat as constant + + 
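+        # DoRA (arXiv:2402.09353) background for the setup below (a summary based on the
+        # reference implementation linked at the top of this file, not an authoritative spec):
+        # the pretrained weight W0 is decomposed into a magnitude vector m and a direction.
+        # lora_up @ lora_down forms the directional update dV, and `self.magnitude` (m) is
+        # initialized to the column-wise L2 norm ||W0 + dV||_c, so at init (lora_up is zeroed)
+        # the module is a no-op. apply_dora() below returns the correction term so that,
+        # combined with the base output and the scaled LoRA delta added in the network mixin's
+        # forward (not shown in this file, so this is an inference), the effective weight is
+        # m * (W0 + dV) / ||W0 + dV||_c, with the norm detached per section 4.3 of the paper.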
self.multiplier: Union[float, List[float]] = multiplier + # wrap the original module so it doesn't get weights updated + self.org_module = [org_module] + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.is_checkpointing = False + + d_out = org_module.out_features + d_in = org_module.in_features + + std_dev = 1 / torch.sqrt(torch.tensor(self.lora_dim).float()) + # self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev) # lora_A + # self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in)) # lora_B + self.lora_up = nn.Linear(self.lora_dim, d_out, bias=False) # lora_B + # self.lora_up.weight.data = torch.randn_like(self.lora_up.weight.data) * std_dev + self.lora_up.weight.data = torch.zeros_like(self.lora_up.weight.data) + # self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) + # self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False) + self.lora_down = nn.Linear(d_in, self.lora_dim, bias=False) # lora_A + # self.lora_down.weight.data = torch.zeros_like(self.lora_down.weight.data) + self.lora_down.weight.data = torch.randn_like(self.lora_down.weight.data) * std_dev + + # m = Magnitude column-wise across output dimension + weight = self.get_orig_weight() + weight = weight.to(self.lora_up.weight.device, dtype=self.lora_up.weight.dtype) + lora_weight = self.lora_up.weight @ self.lora_down.weight + weight_norm = self._get_weight_norm(weight, lora_weight) + self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True) + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + # del self.org_module + + def get_orig_weight(self): + weight = self.org_module[0].weight + if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor): + return weight.dequantize().data.detach() + else: + return weight.data.detach() + + def get_orig_bias(self): + if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None: + return self.org_module[0].bias.data.detach() + return None + + # def dora_forward(self, x, *args, **kwargs): + # lora = torch.matmul(self.lora_A, self.lora_B) + # adapted = self.get_orig_weight() + lora + # column_norm = adapted.norm(p=2, dim=0, keepdim=True) + # norm_adapted = adapted / column_norm + # calc_weights = self.magnitude * norm_adapted + # return F.linear(x, calc_weights, self.get_orig_bias()) + + def _get_weight_norm(self, weight, scaled_lora_weight) -> torch.Tensor: + # calculate L2 norm of weight matrix, column-wise + weight = weight + scaled_lora_weight.to(weight.device) + weight_norm = torch.linalg.norm(weight, dim=1) + return weight_norm + + def apply_dora(self, x, scaled_lora_weight): + # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L192 + # lora weight is already scaled + + # magnitude = self.lora_magnitude_vector[active_adapter] + weight = self.get_orig_weight() + weight = weight.to(scaled_lora_weight.device, dtype=scaled_lora_weight.dtype) + weight_norm = self._get_weight_norm(weight, scaled_lora_weight) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. 
This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + weight_norm = weight_norm.detach() + dora_weight = transpose(weight + scaled_lora_weight, False) + return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x.to(dora_weight.dtype), dora_weight) diff --git a/toolkit/models/LoRAFormer.py b/toolkit/models/LoRAFormer.py new file mode 100644 index 0000000000000000000000000000000000000000..78bb460de413129f7b94782476bbd89b112c8542 --- /dev/null +++ b/toolkit/models/LoRAFormer.py @@ -0,0 +1,267 @@ +import math +import weakref + +import torch +import torch.nn as nn +from typing import TYPE_CHECKING, List, Dict, Any +from toolkit.models.clip_fusion import ZipperBlock +from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler +import sys +from toolkit.paths import REPOS_ROOT +sys.path.append(REPOS_ROOT) +from ipadapter.ip_adapter.resampler import Resampler +from collections import OrderedDict + +if TYPE_CHECKING: + from toolkit.lora_special import LoRAModule + from toolkit.stable_diffusion_model import StableDiffusion + + +class TransformerBlock(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + nn.ReLU(), + nn.Linear(dim_feedforward, d_model) + ) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + + def forward(self, x, cross_attn_input): + # Self-attention + attn_output, _ = self.self_attn(x, x, x) + x = self.norm1(x + attn_output) + + # Cross-attention + cross_attn_output, _ = self.cross_attn(x, cross_attn_input, cross_attn_input) + x = self.norm2(x + cross_attn_output) + + # Feed-forward + ff_output = self.feed_forward(x) + x = self.norm3(x + ff_output) + + return x + + +class InstantLoRAMidModule(torch.nn.Module): + def __init__( + self, + index: int, + lora_module: 'LoRAModule', + instant_lora_module: 'InstantLoRAModule', + up_shape: list = None, + down_shape: list = None, + ): + super(InstantLoRAMidModule, self).__init__() + self.up_shape = up_shape + self.down_shape = down_shape + self.index = index + self.lora_module_ref = weakref.ref(lora_module) + self.instant_lora_module_ref = weakref.ref(instant_lora_module) + + self.embed = None + + def down_forward(self, x, *args, **kwargs): + # get the embed + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + down_size = math.prod(self.down_shape) + down_weight = self.embed[:, :down_size] + + batch_size = x.shape[0] + + # unconditional + if down_weight.shape[0] * 2 == batch_size: + down_weight = torch.cat([down_weight] * 2, dim=0) + + weight_chunks = torch.chunk(down_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + weight_chunk = weight_chunk.view(self.down_shape) + # check if is conv or linear + if len(weight_chunk.shape) == 4: + padding = 0 + if weight_chunk.shape[-1] == 3: + padding = 1 + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) + else: + # run a simple linear layer with the down weight + x_chunk = x_chunk @ weight_chunk.T + x_out.append(x_chunk) + x = torch.cat(x_out, dim=0) + return x + + + def up_forward(self, x, *args, 
**kwargs): + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + up_size = math.prod(self.up_shape) + up_weight = self.embed[:, -up_size:] + + batch_size = x.shape[0] + + # unconditional + if up_weight.shape[0] * 2 == batch_size: + up_weight = torch.cat([up_weight] * 2, dim=0) + + weight_chunks = torch.chunk(up_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + weight_chunk = weight_chunk.view(self.up_shape) + # check if is conv or linear + if len(weight_chunk.shape) == 4: + padding = 0 + if weight_chunk.shape[-1] == 3: + padding = 1 + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) + else: + # run a simple linear layer with the down weight + x_chunk = x_chunk @ weight_chunk.T + x_out.append(x_chunk) + x = torch.cat(x_out, dim=0) + return x + + +# Initialize the network +# num_blocks = 8 +# d_model = 1024 # Adjust as needed +# nhead = 16 # Adjust as needed +# dim_feedforward = 4096 # Adjust as needed +# latent_dim = 1695744 + +class LoRAFormer(torch.nn.Module): + def __init__( + self, + num_blocks, + d_model=1024, + nhead=16, + dim_feedforward=4096, + sd: 'StableDiffusion'=None, + ): + super(LoRAFormer, self).__init__() + # self.linear = torch.nn.Linear(2, 1) + self.sd_ref = weakref.ref(sd) + self.dim = sd.network.lora_dim + + # stores the projection vector. Grabbed by modules + self.img_embeds: List[torch.Tensor] = None + + # disable merging in. It is slower on inference + self.sd_ref().network.can_merge_in = False + + self.ilora_modules = torch.nn.ModuleList() + + lora_modules = self.sd_ref().network.get_all_modules() + + output_size = 0 + + self.embed_lengths = [] + self.weight_mapping = [] + + for idx, lora_module in enumerate(lora_modules): + module_dict = lora_module.state_dict() + down_shape = list(module_dict['lora_down.weight'].shape) + up_shape = list(module_dict['lora_up.weight'].shape) + + self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]]) + + module_size = math.prod(down_shape) + math.prod(up_shape) + output_size += module_size + self.embed_lengths.append(module_size) + + + # add a new mid module that will take the original forward and add a vector to it + # this will be used to add the vector to the original forward + instant_module = InstantLoRAMidModule( + idx, + lora_module, + self, + up_shape=up_shape, + down_shape=down_shape + ) + + self.ilora_modules.append(instant_module) + + # replace the LoRA forwards + lora_module.lora_down.forward = instant_module.down_forward + lora_module.lora_up.forward = instant_module.up_forward + + + self.output_size = output_size + + self.latent = nn.Parameter(torch.randn(1, output_size)) + self.latent_proj = nn.Linear(output_size, d_model) + self.blocks = nn.ModuleList([ + TransformerBlock(d_model, nhead, dim_feedforward) + for _ in range(num_blocks) + ]) + self.final_proj = nn.Linear(d_model, output_size) + + self.migrate_weight_mapping() + + def migrate_weight_mapping(self): + return + # # changes the names of the modules to common ones + # keymap = self.sd_ref().network.get_keymap() + # save_keymap = {} + # if keymap is not None: + # for ldm_key, diffusers_key in keymap.items(): + # # invert them + # save_keymap[diffusers_key] = ldm_key + # + # new_keymap = {} + # for key, value in self.weight_mapping: + # if key in save_keymap: + # new_keymap[save_keymap[key]] = value + # else: + # print(f"Key {key} not found in keymap") + # 
new_keymap[key] = value + # self.weight_mapping = new_keymap + # else: + # print("No keymap found. Using default names") + # return + + + def forward(self, img_embeds): + # expand token rank if only rank 2 + if len(img_embeds.shape) == 2: + img_embeds = img_embeds.unsqueeze(1) + + # resample the image embeddings + img_embeds = self.resampler(img_embeds) + img_embeds = self.proj_module(img_embeds) + if len(img_embeds.shape) == 3: + # merge the heads + img_embeds = img_embeds.mean(dim=1) + + self.img_embeds = [] + # get all the slices + start = 0 + for length in self.embed_lengths: + self.img_embeds.append(img_embeds[:, start:start+length]) + start += length + + + def get_additional_save_metadata(self) -> Dict[str, Any]: + # save the weight mapping + return { + "weight_mapping": self.weight_mapping, + "num_heads": self.num_heads, + "vision_hidden_size": self.vision_hidden_size, + "head_dim": self.head_dim, + "vision_tokens": self.vision_tokens, + "output_size": self.output_size, + } + diff --git a/toolkit/models/RRDB.py b/toolkit/models/RRDB.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a2ad955309d2d5bb7a19e61812e4a4a761fa2e --- /dev/null +++ b/toolkit/models/RRDB.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import functools +import math +import re +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . import block as B + +esrgan_safetensors_keys = ['model.0.weight', 'model.0.bias', 'model.1.sub.0.RDB1.conv1.0.weight', + 'model.1.sub.0.RDB1.conv1.0.bias', 'model.1.sub.0.RDB1.conv2.0.weight', + 'model.1.sub.0.RDB1.conv2.0.bias', 'model.1.sub.0.RDB1.conv3.0.weight', + 'model.1.sub.0.RDB1.conv3.0.bias', 'model.1.sub.0.RDB1.conv4.0.weight', + 'model.1.sub.0.RDB1.conv4.0.bias', 'model.1.sub.0.RDB1.conv5.0.weight', + 'model.1.sub.0.RDB1.conv5.0.bias', 'model.1.sub.0.RDB2.conv1.0.weight', + 'model.1.sub.0.RDB2.conv1.0.bias', 'model.1.sub.0.RDB2.conv2.0.weight', + 'model.1.sub.0.RDB2.conv2.0.bias', 'model.1.sub.0.RDB2.conv3.0.weight', + 'model.1.sub.0.RDB2.conv3.0.bias', 'model.1.sub.0.RDB2.conv4.0.weight', + 'model.1.sub.0.RDB2.conv4.0.bias', 'model.1.sub.0.RDB2.conv5.0.weight', + 'model.1.sub.0.RDB2.conv5.0.bias', 'model.1.sub.0.RDB3.conv1.0.weight', + 'model.1.sub.0.RDB3.conv1.0.bias', 'model.1.sub.0.RDB3.conv2.0.weight', + 'model.1.sub.0.RDB3.conv2.0.bias', 'model.1.sub.0.RDB3.conv3.0.weight', + 'model.1.sub.0.RDB3.conv3.0.bias', 'model.1.sub.0.RDB3.conv4.0.weight', + 'model.1.sub.0.RDB3.conv4.0.bias', 'model.1.sub.0.RDB3.conv5.0.weight', + 'model.1.sub.0.RDB3.conv5.0.bias', 'model.1.sub.1.RDB1.conv1.0.weight', + 'model.1.sub.1.RDB1.conv1.0.bias', 'model.1.sub.1.RDB1.conv2.0.weight', + 'model.1.sub.1.RDB1.conv2.0.bias', 'model.1.sub.1.RDB1.conv3.0.weight', + 'model.1.sub.1.RDB1.conv3.0.bias', 'model.1.sub.1.RDB1.conv4.0.weight', + 'model.1.sub.1.RDB1.conv4.0.bias', 'model.1.sub.1.RDB1.conv5.0.weight', + 'model.1.sub.1.RDB1.conv5.0.bias', 'model.1.sub.1.RDB2.conv1.0.weight', + 'model.1.sub.1.RDB2.conv1.0.bias', 'model.1.sub.1.RDB2.conv2.0.weight', + 'model.1.sub.1.RDB2.conv2.0.bias', 'model.1.sub.1.RDB2.conv3.0.weight', + 'model.1.sub.1.RDB2.conv3.0.bias', 'model.1.sub.1.RDB2.conv4.0.weight', + 'model.1.sub.1.RDB2.conv4.0.bias', 'model.1.sub.1.RDB2.conv5.0.weight', + 'model.1.sub.1.RDB2.conv5.0.bias', 'model.1.sub.1.RDB3.conv1.0.weight', + 'model.1.sub.1.RDB3.conv1.0.bias', 'model.1.sub.1.RDB3.conv2.0.weight', + 'model.1.sub.1.RDB3.conv2.0.bias', 
'model.1.sub.1.RDB3.conv3.0.weight', + 'model.1.sub.1.RDB3.conv3.0.bias', 'model.1.sub.1.RDB3.conv4.0.weight', + 'model.1.sub.1.RDB3.conv4.0.bias', 'model.1.sub.1.RDB3.conv5.0.weight', + 'model.1.sub.1.RDB3.conv5.0.bias', 'model.1.sub.2.RDB1.conv1.0.weight', + 'model.1.sub.2.RDB1.conv1.0.bias', 'model.1.sub.2.RDB1.conv2.0.weight', + 'model.1.sub.2.RDB1.conv2.0.bias', 'model.1.sub.2.RDB1.conv3.0.weight', + 'model.1.sub.2.RDB1.conv3.0.bias', 'model.1.sub.2.RDB1.conv4.0.weight', + 'model.1.sub.2.RDB1.conv4.0.bias', 'model.1.sub.2.RDB1.conv5.0.weight', + 'model.1.sub.2.RDB1.conv5.0.bias', 'model.1.sub.2.RDB2.conv1.0.weight', + 'model.1.sub.2.RDB2.conv1.0.bias', 'model.1.sub.2.RDB2.conv2.0.weight', + 'model.1.sub.2.RDB2.conv2.0.bias', 'model.1.sub.2.RDB2.conv3.0.weight', + 'model.1.sub.2.RDB2.conv3.0.bias', 'model.1.sub.2.RDB2.conv4.0.weight', + 'model.1.sub.2.RDB2.conv4.0.bias', 'model.1.sub.2.RDB2.conv5.0.weight', + 'model.1.sub.2.RDB2.conv5.0.bias', 'model.1.sub.2.RDB3.conv1.0.weight', + 'model.1.sub.2.RDB3.conv1.0.bias', 'model.1.sub.2.RDB3.conv2.0.weight', + 'model.1.sub.2.RDB3.conv2.0.bias', 'model.1.sub.2.RDB3.conv3.0.weight', + 'model.1.sub.2.RDB3.conv3.0.bias', 'model.1.sub.2.RDB3.conv4.0.weight', + 'model.1.sub.2.RDB3.conv4.0.bias', 'model.1.sub.2.RDB3.conv5.0.weight', + 'model.1.sub.2.RDB3.conv5.0.bias', 'model.1.sub.3.RDB1.conv1.0.weight', + 'model.1.sub.3.RDB1.conv1.0.bias', 'model.1.sub.3.RDB1.conv2.0.weight', + 'model.1.sub.3.RDB1.conv2.0.bias', 'model.1.sub.3.RDB1.conv3.0.weight', + 'model.1.sub.3.RDB1.conv3.0.bias', 'model.1.sub.3.RDB1.conv4.0.weight', + 'model.1.sub.3.RDB1.conv4.0.bias', 'model.1.sub.3.RDB1.conv5.0.weight', + 'model.1.sub.3.RDB1.conv5.0.bias', 'model.1.sub.3.RDB2.conv1.0.weight', + 'model.1.sub.3.RDB2.conv1.0.bias', 'model.1.sub.3.RDB2.conv2.0.weight', + 'model.1.sub.3.RDB2.conv2.0.bias', 'model.1.sub.3.RDB2.conv3.0.weight', + 'model.1.sub.3.RDB2.conv3.0.bias', 'model.1.sub.3.RDB2.conv4.0.weight', + 'model.1.sub.3.RDB2.conv4.0.bias', 'model.1.sub.3.RDB2.conv5.0.weight', + 'model.1.sub.3.RDB2.conv5.0.bias', 'model.1.sub.3.RDB3.conv1.0.weight', + 'model.1.sub.3.RDB3.conv1.0.bias', 'model.1.sub.3.RDB3.conv2.0.weight', + 'model.1.sub.3.RDB3.conv2.0.bias', 'model.1.sub.3.RDB3.conv3.0.weight', + 'model.1.sub.3.RDB3.conv3.0.bias', 'model.1.sub.3.RDB3.conv4.0.weight', + 'model.1.sub.3.RDB3.conv4.0.bias', 'model.1.sub.3.RDB3.conv5.0.weight', + 'model.1.sub.3.RDB3.conv5.0.bias', 'model.1.sub.4.RDB1.conv1.0.weight', + 'model.1.sub.4.RDB1.conv1.0.bias', 'model.1.sub.4.RDB1.conv2.0.weight', + 'model.1.sub.4.RDB1.conv2.0.bias', 'model.1.sub.4.RDB1.conv3.0.weight', + 'model.1.sub.4.RDB1.conv3.0.bias', 'model.1.sub.4.RDB1.conv4.0.weight', + 'model.1.sub.4.RDB1.conv4.0.bias', 'model.1.sub.4.RDB1.conv5.0.weight', + 'model.1.sub.4.RDB1.conv5.0.bias', 'model.1.sub.4.RDB2.conv1.0.weight', + 'model.1.sub.4.RDB2.conv1.0.bias', 'model.1.sub.4.RDB2.conv2.0.weight', + 'model.1.sub.4.RDB2.conv2.0.bias', 'model.1.sub.4.RDB2.conv3.0.weight', + 'model.1.sub.4.RDB2.conv3.0.bias', 'model.1.sub.4.RDB2.conv4.0.weight', + 'model.1.sub.4.RDB2.conv4.0.bias', 'model.1.sub.4.RDB2.conv5.0.weight', + 'model.1.sub.4.RDB2.conv5.0.bias', 'model.1.sub.4.RDB3.conv1.0.weight', + 'model.1.sub.4.RDB3.conv1.0.bias', 'model.1.sub.4.RDB3.conv2.0.weight', + 'model.1.sub.4.RDB3.conv2.0.bias', 'model.1.sub.4.RDB3.conv3.0.weight', + 'model.1.sub.4.RDB3.conv3.0.bias', 'model.1.sub.4.RDB3.conv4.0.weight', + 'model.1.sub.4.RDB3.conv4.0.bias', 'model.1.sub.4.RDB3.conv5.0.weight', + 'model.1.sub.4.RDB3.conv5.0.bias', 
'model.1.sub.5.RDB1.conv1.0.weight', + 'model.1.sub.5.RDB1.conv1.0.bias', 'model.1.sub.5.RDB1.conv2.0.weight', + 'model.1.sub.5.RDB1.conv2.0.bias', 'model.1.sub.5.RDB1.conv3.0.weight', + 'model.1.sub.5.RDB1.conv3.0.bias', 'model.1.sub.5.RDB1.conv4.0.weight', + 'model.1.sub.5.RDB1.conv4.0.bias', 'model.1.sub.5.RDB1.conv5.0.weight', + 'model.1.sub.5.RDB1.conv5.0.bias', 'model.1.sub.5.RDB2.conv1.0.weight', + 'model.1.sub.5.RDB2.conv1.0.bias', 'model.1.sub.5.RDB2.conv2.0.weight', + 'model.1.sub.5.RDB2.conv2.0.bias', 'model.1.sub.5.RDB2.conv3.0.weight', + 'model.1.sub.5.RDB2.conv3.0.bias', 'model.1.sub.5.RDB2.conv4.0.weight', + 'model.1.sub.5.RDB2.conv4.0.bias', 'model.1.sub.5.RDB2.conv5.0.weight', + 'model.1.sub.5.RDB2.conv5.0.bias', 'model.1.sub.5.RDB3.conv1.0.weight', + 'model.1.sub.5.RDB3.conv1.0.bias', 'model.1.sub.5.RDB3.conv2.0.weight', + 'model.1.sub.5.RDB3.conv2.0.bias', 'model.1.sub.5.RDB3.conv3.0.weight', + 'model.1.sub.5.RDB3.conv3.0.bias', 'model.1.sub.5.RDB3.conv4.0.weight', + 'model.1.sub.5.RDB3.conv4.0.bias', 'model.1.sub.5.RDB3.conv5.0.weight', + 'model.1.sub.5.RDB3.conv5.0.bias', 'model.1.sub.6.RDB1.conv1.0.weight', + 'model.1.sub.6.RDB1.conv1.0.bias', 'model.1.sub.6.RDB1.conv2.0.weight', + 'model.1.sub.6.RDB1.conv2.0.bias', 'model.1.sub.6.RDB1.conv3.0.weight', + 'model.1.sub.6.RDB1.conv3.0.bias', 'model.1.sub.6.RDB1.conv4.0.weight', + 'model.1.sub.6.RDB1.conv4.0.bias', 'model.1.sub.6.RDB1.conv5.0.weight', + 'model.1.sub.6.RDB1.conv5.0.bias', 'model.1.sub.6.RDB2.conv1.0.weight', + 'model.1.sub.6.RDB2.conv1.0.bias', 'model.1.sub.6.RDB2.conv2.0.weight', + 'model.1.sub.6.RDB2.conv2.0.bias', 'model.1.sub.6.RDB2.conv3.0.weight', + 'model.1.sub.6.RDB2.conv3.0.bias', 'model.1.sub.6.RDB2.conv4.0.weight', + 'model.1.sub.6.RDB2.conv4.0.bias', 'model.1.sub.6.RDB2.conv5.0.weight', + 'model.1.sub.6.RDB2.conv5.0.bias', 'model.1.sub.6.RDB3.conv1.0.weight', + 'model.1.sub.6.RDB3.conv1.0.bias', 'model.1.sub.6.RDB3.conv2.0.weight', + 'model.1.sub.6.RDB3.conv2.0.bias', 'model.1.sub.6.RDB3.conv3.0.weight', + 'model.1.sub.6.RDB3.conv3.0.bias', 'model.1.sub.6.RDB3.conv4.0.weight', + 'model.1.sub.6.RDB3.conv4.0.bias', 'model.1.sub.6.RDB3.conv5.0.weight', + 'model.1.sub.6.RDB3.conv5.0.bias', 'model.1.sub.7.RDB1.conv1.0.weight', + 'model.1.sub.7.RDB1.conv1.0.bias', 'model.1.sub.7.RDB1.conv2.0.weight', + 'model.1.sub.7.RDB1.conv2.0.bias', 'model.1.sub.7.RDB1.conv3.0.weight', + 'model.1.sub.7.RDB1.conv3.0.bias', 'model.1.sub.7.RDB1.conv4.0.weight', + 'model.1.sub.7.RDB1.conv4.0.bias', 'model.1.sub.7.RDB1.conv5.0.weight', + 'model.1.sub.7.RDB1.conv5.0.bias', 'model.1.sub.7.RDB2.conv1.0.weight', + 'model.1.sub.7.RDB2.conv1.0.bias', 'model.1.sub.7.RDB2.conv2.0.weight', + 'model.1.sub.7.RDB2.conv2.0.bias', 'model.1.sub.7.RDB2.conv3.0.weight', + 'model.1.sub.7.RDB2.conv3.0.bias', 'model.1.sub.7.RDB2.conv4.0.weight', + 'model.1.sub.7.RDB2.conv4.0.bias', 'model.1.sub.7.RDB2.conv5.0.weight', + 'model.1.sub.7.RDB2.conv5.0.bias', 'model.1.sub.7.RDB3.conv1.0.weight', + 'model.1.sub.7.RDB3.conv1.0.bias', 'model.1.sub.7.RDB3.conv2.0.weight', + 'model.1.sub.7.RDB3.conv2.0.bias', 'model.1.sub.7.RDB3.conv3.0.weight', + 'model.1.sub.7.RDB3.conv3.0.bias', 'model.1.sub.7.RDB3.conv4.0.weight', + 'model.1.sub.7.RDB3.conv4.0.bias', 'model.1.sub.7.RDB3.conv5.0.weight', + 'model.1.sub.7.RDB3.conv5.0.bias', 'model.1.sub.8.RDB1.conv1.0.weight', + 'model.1.sub.8.RDB1.conv1.0.bias', 'model.1.sub.8.RDB1.conv2.0.weight', + 'model.1.sub.8.RDB1.conv2.0.bias', 'model.1.sub.8.RDB1.conv3.0.weight', + 'model.1.sub.8.RDB1.conv3.0.bias', 
'model.1.sub.8.RDB1.conv4.0.weight', + 'model.1.sub.8.RDB1.conv4.0.bias', 'model.1.sub.8.RDB1.conv5.0.weight', + 'model.1.sub.8.RDB1.conv5.0.bias', 'model.1.sub.8.RDB2.conv1.0.weight', + 'model.1.sub.8.RDB2.conv1.0.bias', 'model.1.sub.8.RDB2.conv2.0.weight', + 'model.1.sub.8.RDB2.conv2.0.bias', 'model.1.sub.8.RDB2.conv3.0.weight', + 'model.1.sub.8.RDB2.conv3.0.bias', 'model.1.sub.8.RDB2.conv4.0.weight', + 'model.1.sub.8.RDB2.conv4.0.bias', 'model.1.sub.8.RDB2.conv5.0.weight', + 'model.1.sub.8.RDB2.conv5.0.bias', 'model.1.sub.8.RDB3.conv1.0.weight', + 'model.1.sub.8.RDB3.conv1.0.bias', 'model.1.sub.8.RDB3.conv2.0.weight', + 'model.1.sub.8.RDB3.conv2.0.bias', 'model.1.sub.8.RDB3.conv3.0.weight', + 'model.1.sub.8.RDB3.conv3.0.bias', 'model.1.sub.8.RDB3.conv4.0.weight', + 'model.1.sub.8.RDB3.conv4.0.bias', 'model.1.sub.8.RDB3.conv5.0.weight', + 'model.1.sub.8.RDB3.conv5.0.bias', 'model.1.sub.9.RDB1.conv1.0.weight', + 'model.1.sub.9.RDB1.conv1.0.bias', 'model.1.sub.9.RDB1.conv2.0.weight', + 'model.1.sub.9.RDB1.conv2.0.bias', 'model.1.sub.9.RDB1.conv3.0.weight', + 'model.1.sub.9.RDB1.conv3.0.bias', 'model.1.sub.9.RDB1.conv4.0.weight', + 'model.1.sub.9.RDB1.conv4.0.bias', 'model.1.sub.9.RDB1.conv5.0.weight', + 'model.1.sub.9.RDB1.conv5.0.bias', 'model.1.sub.9.RDB2.conv1.0.weight', + 'model.1.sub.9.RDB2.conv1.0.bias', 'model.1.sub.9.RDB2.conv2.0.weight', + 'model.1.sub.9.RDB2.conv2.0.bias', 'model.1.sub.9.RDB2.conv3.0.weight', + 'model.1.sub.9.RDB2.conv3.0.bias', 'model.1.sub.9.RDB2.conv4.0.weight', + 'model.1.sub.9.RDB2.conv4.0.bias', 'model.1.sub.9.RDB2.conv5.0.weight', + 'model.1.sub.9.RDB2.conv5.0.bias', 'model.1.sub.9.RDB3.conv1.0.weight', + 'model.1.sub.9.RDB3.conv1.0.bias', 'model.1.sub.9.RDB3.conv2.0.weight', + 'model.1.sub.9.RDB3.conv2.0.bias', 'model.1.sub.9.RDB3.conv3.0.weight', + 'model.1.sub.9.RDB3.conv3.0.bias', 'model.1.sub.9.RDB3.conv4.0.weight', + 'model.1.sub.9.RDB3.conv4.0.bias', 'model.1.sub.9.RDB3.conv5.0.weight', + 'model.1.sub.9.RDB3.conv5.0.bias', 'model.1.sub.10.RDB1.conv1.0.weight', + 'model.1.sub.10.RDB1.conv1.0.bias', 'model.1.sub.10.RDB1.conv2.0.weight', + 'model.1.sub.10.RDB1.conv2.0.bias', 'model.1.sub.10.RDB1.conv3.0.weight', + 'model.1.sub.10.RDB1.conv3.0.bias', 'model.1.sub.10.RDB1.conv4.0.weight', + 'model.1.sub.10.RDB1.conv4.0.bias', 'model.1.sub.10.RDB1.conv5.0.weight', + 'model.1.sub.10.RDB1.conv5.0.bias', 'model.1.sub.10.RDB2.conv1.0.weight', + 'model.1.sub.10.RDB2.conv1.0.bias', 'model.1.sub.10.RDB2.conv2.0.weight', + 'model.1.sub.10.RDB2.conv2.0.bias', 'model.1.sub.10.RDB2.conv3.0.weight', + 'model.1.sub.10.RDB2.conv3.0.bias', 'model.1.sub.10.RDB2.conv4.0.weight', + 'model.1.sub.10.RDB2.conv4.0.bias', 'model.1.sub.10.RDB2.conv5.0.weight', + 'model.1.sub.10.RDB2.conv5.0.bias', 'model.1.sub.10.RDB3.conv1.0.weight', + 'model.1.sub.10.RDB3.conv1.0.bias', 'model.1.sub.10.RDB3.conv2.0.weight', + 'model.1.sub.10.RDB3.conv2.0.bias', 'model.1.sub.10.RDB3.conv3.0.weight', + 'model.1.sub.10.RDB3.conv3.0.bias', 'model.1.sub.10.RDB3.conv4.0.weight', + 'model.1.sub.10.RDB3.conv4.0.bias', 'model.1.sub.10.RDB3.conv5.0.weight', + 'model.1.sub.10.RDB3.conv5.0.bias', 'model.1.sub.11.RDB1.conv1.0.weight', + 'model.1.sub.11.RDB1.conv1.0.bias', 'model.1.sub.11.RDB1.conv2.0.weight', + 'model.1.sub.11.RDB1.conv2.0.bias', 'model.1.sub.11.RDB1.conv3.0.weight', + 'model.1.sub.11.RDB1.conv3.0.bias', 'model.1.sub.11.RDB1.conv4.0.weight', + 'model.1.sub.11.RDB1.conv4.0.bias', 'model.1.sub.11.RDB1.conv5.0.weight', + 'model.1.sub.11.RDB1.conv5.0.bias', 
'model.1.sub.11.RDB2.conv1.0.weight', + 'model.1.sub.11.RDB2.conv1.0.bias', 'model.1.sub.11.RDB2.conv2.0.weight', + 'model.1.sub.11.RDB2.conv2.0.bias', 'model.1.sub.11.RDB2.conv3.0.weight', + 'model.1.sub.11.RDB2.conv3.0.bias', 'model.1.sub.11.RDB2.conv4.0.weight', + 'model.1.sub.11.RDB2.conv4.0.bias', 'model.1.sub.11.RDB2.conv5.0.weight', + 'model.1.sub.11.RDB2.conv5.0.bias', 'model.1.sub.11.RDB3.conv1.0.weight', + 'model.1.sub.11.RDB3.conv1.0.bias', 'model.1.sub.11.RDB3.conv2.0.weight', + 'model.1.sub.11.RDB3.conv2.0.bias', 'model.1.sub.11.RDB3.conv3.0.weight', + 'model.1.sub.11.RDB3.conv3.0.bias', 'model.1.sub.11.RDB3.conv4.0.weight', + 'model.1.sub.11.RDB3.conv4.0.bias', 'model.1.sub.11.RDB3.conv5.0.weight', + 'model.1.sub.11.RDB3.conv5.0.bias', 'model.1.sub.12.RDB1.conv1.0.weight', + 'model.1.sub.12.RDB1.conv1.0.bias', 'model.1.sub.12.RDB1.conv2.0.weight', + 'model.1.sub.12.RDB1.conv2.0.bias', 'model.1.sub.12.RDB1.conv3.0.weight', + 'model.1.sub.12.RDB1.conv3.0.bias', 'model.1.sub.12.RDB1.conv4.0.weight', + 'model.1.sub.12.RDB1.conv4.0.bias', 'model.1.sub.12.RDB1.conv5.0.weight', + 'model.1.sub.12.RDB1.conv5.0.bias', 'model.1.sub.12.RDB2.conv1.0.weight', + 'model.1.sub.12.RDB2.conv1.0.bias', 'model.1.sub.12.RDB2.conv2.0.weight', + 'model.1.sub.12.RDB2.conv2.0.bias', 'model.1.sub.12.RDB2.conv3.0.weight', + 'model.1.sub.12.RDB2.conv3.0.bias', 'model.1.sub.12.RDB2.conv4.0.weight', + 'model.1.sub.12.RDB2.conv4.0.bias', 'model.1.sub.12.RDB2.conv5.0.weight', + 'model.1.sub.12.RDB2.conv5.0.bias', 'model.1.sub.12.RDB3.conv1.0.weight', + 'model.1.sub.12.RDB3.conv1.0.bias', 'model.1.sub.12.RDB3.conv2.0.weight', + 'model.1.sub.12.RDB3.conv2.0.bias', 'model.1.sub.12.RDB3.conv3.0.weight', + 'model.1.sub.12.RDB3.conv3.0.bias', 'model.1.sub.12.RDB3.conv4.0.weight', + 'model.1.sub.12.RDB3.conv4.0.bias', 'model.1.sub.12.RDB3.conv5.0.weight', + 'model.1.sub.12.RDB3.conv5.0.bias', 'model.1.sub.13.RDB1.conv1.0.weight', + 'model.1.sub.13.RDB1.conv1.0.bias', 'model.1.sub.13.RDB1.conv2.0.weight', + 'model.1.sub.13.RDB1.conv2.0.bias', 'model.1.sub.13.RDB1.conv3.0.weight', + 'model.1.sub.13.RDB1.conv3.0.bias', 'model.1.sub.13.RDB1.conv4.0.weight', + 'model.1.sub.13.RDB1.conv4.0.bias', 'model.1.sub.13.RDB1.conv5.0.weight', + 'model.1.sub.13.RDB1.conv5.0.bias', 'model.1.sub.13.RDB2.conv1.0.weight', + 'model.1.sub.13.RDB2.conv1.0.bias', 'model.1.sub.13.RDB2.conv2.0.weight', + 'model.1.sub.13.RDB2.conv2.0.bias', 'model.1.sub.13.RDB2.conv3.0.weight', + 'model.1.sub.13.RDB2.conv3.0.bias', 'model.1.sub.13.RDB2.conv4.0.weight', + 'model.1.sub.13.RDB2.conv4.0.bias', 'model.1.sub.13.RDB2.conv5.0.weight', + 'model.1.sub.13.RDB2.conv5.0.bias', 'model.1.sub.13.RDB3.conv1.0.weight', + 'model.1.sub.13.RDB3.conv1.0.bias', 'model.1.sub.13.RDB3.conv2.0.weight', + 'model.1.sub.13.RDB3.conv2.0.bias', 'model.1.sub.13.RDB3.conv3.0.weight', + 'model.1.sub.13.RDB3.conv3.0.bias', 'model.1.sub.13.RDB3.conv4.0.weight', + 'model.1.sub.13.RDB3.conv4.0.bias', 'model.1.sub.13.RDB3.conv5.0.weight', + 'model.1.sub.13.RDB3.conv5.0.bias', 'model.1.sub.14.RDB1.conv1.0.weight', + 'model.1.sub.14.RDB1.conv1.0.bias', 'model.1.sub.14.RDB1.conv2.0.weight', + 'model.1.sub.14.RDB1.conv2.0.bias', 'model.1.sub.14.RDB1.conv3.0.weight', + 'model.1.sub.14.RDB1.conv3.0.bias', 'model.1.sub.14.RDB1.conv4.0.weight', + 'model.1.sub.14.RDB1.conv4.0.bias', 'model.1.sub.14.RDB1.conv5.0.weight', + 'model.1.sub.14.RDB1.conv5.0.bias', 'model.1.sub.14.RDB2.conv1.0.weight', + 'model.1.sub.14.RDB2.conv1.0.bias', 'model.1.sub.14.RDB2.conv2.0.weight', + 
'model.1.sub.14.RDB2.conv2.0.bias', 'model.1.sub.14.RDB2.conv3.0.weight', + 'model.1.sub.14.RDB2.conv3.0.bias', 'model.1.sub.14.RDB2.conv4.0.weight', + 'model.1.sub.14.RDB2.conv4.0.bias', 'model.1.sub.14.RDB2.conv5.0.weight', + 'model.1.sub.14.RDB2.conv5.0.bias', 'model.1.sub.14.RDB3.conv1.0.weight', + 'model.1.sub.14.RDB3.conv1.0.bias', 'model.1.sub.14.RDB3.conv2.0.weight', + 'model.1.sub.14.RDB3.conv2.0.bias', 'model.1.sub.14.RDB3.conv3.0.weight', + 'model.1.sub.14.RDB3.conv3.0.bias', 'model.1.sub.14.RDB3.conv4.0.weight', + 'model.1.sub.14.RDB3.conv4.0.bias', 'model.1.sub.14.RDB3.conv5.0.weight', + 'model.1.sub.14.RDB3.conv5.0.bias', 'model.1.sub.15.RDB1.conv1.0.weight', + 'model.1.sub.15.RDB1.conv1.0.bias', 'model.1.sub.15.RDB1.conv2.0.weight', + 'model.1.sub.15.RDB1.conv2.0.bias', 'model.1.sub.15.RDB1.conv3.0.weight', + 'model.1.sub.15.RDB1.conv3.0.bias', 'model.1.sub.15.RDB1.conv4.0.weight', + 'model.1.sub.15.RDB1.conv4.0.bias', 'model.1.sub.15.RDB1.conv5.0.weight', + 'model.1.sub.15.RDB1.conv5.0.bias', 'model.1.sub.15.RDB2.conv1.0.weight', + 'model.1.sub.15.RDB2.conv1.0.bias', 'model.1.sub.15.RDB2.conv2.0.weight', + 'model.1.sub.15.RDB2.conv2.0.bias', 'model.1.sub.15.RDB2.conv3.0.weight', + 'model.1.sub.15.RDB2.conv3.0.bias', 'model.1.sub.15.RDB2.conv4.0.weight', + 'model.1.sub.15.RDB2.conv4.0.bias', 'model.1.sub.15.RDB2.conv5.0.weight', + 'model.1.sub.15.RDB2.conv5.0.bias', 'model.1.sub.15.RDB3.conv1.0.weight', + 'model.1.sub.15.RDB3.conv1.0.bias', 'model.1.sub.15.RDB3.conv2.0.weight', + 'model.1.sub.15.RDB3.conv2.0.bias', 'model.1.sub.15.RDB3.conv3.0.weight', + 'model.1.sub.15.RDB3.conv3.0.bias', 'model.1.sub.15.RDB3.conv4.0.weight', + 'model.1.sub.15.RDB3.conv4.0.bias', 'model.1.sub.15.RDB3.conv5.0.weight', + 'model.1.sub.15.RDB3.conv5.0.bias', 'model.1.sub.16.RDB1.conv1.0.weight', + 'model.1.sub.16.RDB1.conv1.0.bias', 'model.1.sub.16.RDB1.conv2.0.weight', + 'model.1.sub.16.RDB1.conv2.0.bias', 'model.1.sub.16.RDB1.conv3.0.weight', + 'model.1.sub.16.RDB1.conv3.0.bias', 'model.1.sub.16.RDB1.conv4.0.weight', + 'model.1.sub.16.RDB1.conv4.0.bias', 'model.1.sub.16.RDB1.conv5.0.weight', + 'model.1.sub.16.RDB1.conv5.0.bias', 'model.1.sub.16.RDB2.conv1.0.weight', + 'model.1.sub.16.RDB2.conv1.0.bias', 'model.1.sub.16.RDB2.conv2.0.weight', + 'model.1.sub.16.RDB2.conv2.0.bias', 'model.1.sub.16.RDB2.conv3.0.weight', + 'model.1.sub.16.RDB2.conv3.0.bias', 'model.1.sub.16.RDB2.conv4.0.weight', + 'model.1.sub.16.RDB2.conv4.0.bias', 'model.1.sub.16.RDB2.conv5.0.weight', + 'model.1.sub.16.RDB2.conv5.0.bias', 'model.1.sub.16.RDB3.conv1.0.weight', + 'model.1.sub.16.RDB3.conv1.0.bias', 'model.1.sub.16.RDB3.conv2.0.weight', + 'model.1.sub.16.RDB3.conv2.0.bias', 'model.1.sub.16.RDB3.conv3.0.weight', + 'model.1.sub.16.RDB3.conv3.0.bias', 'model.1.sub.16.RDB3.conv4.0.weight', + 'model.1.sub.16.RDB3.conv4.0.bias', 'model.1.sub.16.RDB3.conv5.0.weight', + 'model.1.sub.16.RDB3.conv5.0.bias', 'model.1.sub.17.RDB1.conv1.0.weight', + 'model.1.sub.17.RDB1.conv1.0.bias', 'model.1.sub.17.RDB1.conv2.0.weight', + 'model.1.sub.17.RDB1.conv2.0.bias', 'model.1.sub.17.RDB1.conv3.0.weight', + 'model.1.sub.17.RDB1.conv3.0.bias', 'model.1.sub.17.RDB1.conv4.0.weight', + 'model.1.sub.17.RDB1.conv4.0.bias', 'model.1.sub.17.RDB1.conv5.0.weight', + 'model.1.sub.17.RDB1.conv5.0.bias', 'model.1.sub.17.RDB2.conv1.0.weight', + 'model.1.sub.17.RDB2.conv1.0.bias', 'model.1.sub.17.RDB2.conv2.0.weight', + 'model.1.sub.17.RDB2.conv2.0.bias', 'model.1.sub.17.RDB2.conv3.0.weight', + 'model.1.sub.17.RDB2.conv3.0.bias', 
'model.1.sub.17.RDB2.conv4.0.weight', + 'model.1.sub.17.RDB2.conv4.0.bias', 'model.1.sub.17.RDB2.conv5.0.weight', + 'model.1.sub.17.RDB2.conv5.0.bias', 'model.1.sub.17.RDB3.conv1.0.weight', + 'model.1.sub.17.RDB3.conv1.0.bias', 'model.1.sub.17.RDB3.conv2.0.weight', + 'model.1.sub.17.RDB3.conv2.0.bias', 'model.1.sub.17.RDB3.conv3.0.weight', + 'model.1.sub.17.RDB3.conv3.0.bias', 'model.1.sub.17.RDB3.conv4.0.weight', + 'model.1.sub.17.RDB3.conv4.0.bias', 'model.1.sub.17.RDB3.conv5.0.weight', + 'model.1.sub.17.RDB3.conv5.0.bias', 'model.1.sub.18.RDB1.conv1.0.weight', + 'model.1.sub.18.RDB1.conv1.0.bias', 'model.1.sub.18.RDB1.conv2.0.weight', + 'model.1.sub.18.RDB1.conv2.0.bias', 'model.1.sub.18.RDB1.conv3.0.weight', + 'model.1.sub.18.RDB1.conv3.0.bias', 'model.1.sub.18.RDB1.conv4.0.weight', + 'model.1.sub.18.RDB1.conv4.0.bias', 'model.1.sub.18.RDB1.conv5.0.weight', + 'model.1.sub.18.RDB1.conv5.0.bias', 'model.1.sub.18.RDB2.conv1.0.weight', + 'model.1.sub.18.RDB2.conv1.0.bias', 'model.1.sub.18.RDB2.conv2.0.weight', + 'model.1.sub.18.RDB2.conv2.0.bias', 'model.1.sub.18.RDB2.conv3.0.weight', + 'model.1.sub.18.RDB2.conv3.0.bias', 'model.1.sub.18.RDB2.conv4.0.weight', + 'model.1.sub.18.RDB2.conv4.0.bias', 'model.1.sub.18.RDB2.conv5.0.weight', + 'model.1.sub.18.RDB2.conv5.0.bias', 'model.1.sub.18.RDB3.conv1.0.weight', + 'model.1.sub.18.RDB3.conv1.0.bias', 'model.1.sub.18.RDB3.conv2.0.weight', + 'model.1.sub.18.RDB3.conv2.0.bias', 'model.1.sub.18.RDB3.conv3.0.weight', + 'model.1.sub.18.RDB3.conv3.0.bias', 'model.1.sub.18.RDB3.conv4.0.weight', + 'model.1.sub.18.RDB3.conv4.0.bias', 'model.1.sub.18.RDB3.conv5.0.weight', + 'model.1.sub.18.RDB3.conv5.0.bias', 'model.1.sub.19.RDB1.conv1.0.weight', + 'model.1.sub.19.RDB1.conv1.0.bias', 'model.1.sub.19.RDB1.conv2.0.weight', + 'model.1.sub.19.RDB1.conv2.0.bias', 'model.1.sub.19.RDB1.conv3.0.weight', + 'model.1.sub.19.RDB1.conv3.0.bias', 'model.1.sub.19.RDB1.conv4.0.weight', + 'model.1.sub.19.RDB1.conv4.0.bias', 'model.1.sub.19.RDB1.conv5.0.weight', + 'model.1.sub.19.RDB1.conv5.0.bias', 'model.1.sub.19.RDB2.conv1.0.weight', + 'model.1.sub.19.RDB2.conv1.0.bias', 'model.1.sub.19.RDB2.conv2.0.weight', + 'model.1.sub.19.RDB2.conv2.0.bias', 'model.1.sub.19.RDB2.conv3.0.weight', + 'model.1.sub.19.RDB2.conv3.0.bias', 'model.1.sub.19.RDB2.conv4.0.weight', + 'model.1.sub.19.RDB2.conv4.0.bias', 'model.1.sub.19.RDB2.conv5.0.weight', + 'model.1.sub.19.RDB2.conv5.0.bias', 'model.1.sub.19.RDB3.conv1.0.weight', + 'model.1.sub.19.RDB3.conv1.0.bias', 'model.1.sub.19.RDB3.conv2.0.weight', + 'model.1.sub.19.RDB3.conv2.0.bias', 'model.1.sub.19.RDB3.conv3.0.weight', + 'model.1.sub.19.RDB3.conv3.0.bias', 'model.1.sub.19.RDB3.conv4.0.weight', + 'model.1.sub.19.RDB3.conv4.0.bias', 'model.1.sub.19.RDB3.conv5.0.weight', + 'model.1.sub.19.RDB3.conv5.0.bias', 'model.1.sub.20.RDB1.conv1.0.weight', + 'model.1.sub.20.RDB1.conv1.0.bias', 'model.1.sub.20.RDB1.conv2.0.weight', + 'model.1.sub.20.RDB1.conv2.0.bias', 'model.1.sub.20.RDB1.conv3.0.weight', + 'model.1.sub.20.RDB1.conv3.0.bias', 'model.1.sub.20.RDB1.conv4.0.weight', + 'model.1.sub.20.RDB1.conv4.0.bias', 'model.1.sub.20.RDB1.conv5.0.weight', + 'model.1.sub.20.RDB1.conv5.0.bias', 'model.1.sub.20.RDB2.conv1.0.weight', + 'model.1.sub.20.RDB2.conv1.0.bias', 'model.1.sub.20.RDB2.conv2.0.weight', + 'model.1.sub.20.RDB2.conv2.0.bias', 'model.1.sub.20.RDB2.conv3.0.weight', + 'model.1.sub.20.RDB2.conv3.0.bias', 'model.1.sub.20.RDB2.conv4.0.weight', + 'model.1.sub.20.RDB2.conv4.0.bias', 'model.1.sub.20.RDB2.conv5.0.weight', + 
'model.1.sub.20.RDB2.conv5.0.bias', 'model.1.sub.20.RDB3.conv1.0.weight', + 'model.1.sub.20.RDB3.conv1.0.bias', 'model.1.sub.20.RDB3.conv2.0.weight', + 'model.1.sub.20.RDB3.conv2.0.bias', 'model.1.sub.20.RDB3.conv3.0.weight', + 'model.1.sub.20.RDB3.conv3.0.bias', 'model.1.sub.20.RDB3.conv4.0.weight', + 'model.1.sub.20.RDB3.conv4.0.bias', 'model.1.sub.20.RDB3.conv5.0.weight', + 'model.1.sub.20.RDB3.conv5.0.bias', 'model.1.sub.21.RDB1.conv1.0.weight', + 'model.1.sub.21.RDB1.conv1.0.bias', 'model.1.sub.21.RDB1.conv2.0.weight', + 'model.1.sub.21.RDB1.conv2.0.bias', 'model.1.sub.21.RDB1.conv3.0.weight', + 'model.1.sub.21.RDB1.conv3.0.bias', 'model.1.sub.21.RDB1.conv4.0.weight', + 'model.1.sub.21.RDB1.conv4.0.bias', 'model.1.sub.21.RDB1.conv5.0.weight', + 'model.1.sub.21.RDB1.conv5.0.bias', 'model.1.sub.21.RDB2.conv1.0.weight', + 'model.1.sub.21.RDB2.conv1.0.bias', 'model.1.sub.21.RDB2.conv2.0.weight', + 'model.1.sub.21.RDB2.conv2.0.bias', 'model.1.sub.21.RDB2.conv3.0.weight', + 'model.1.sub.21.RDB2.conv3.0.bias', 'model.1.sub.21.RDB2.conv4.0.weight', + 'model.1.sub.21.RDB2.conv4.0.bias', 'model.1.sub.21.RDB2.conv5.0.weight', + 'model.1.sub.21.RDB2.conv5.0.bias', 'model.1.sub.21.RDB3.conv1.0.weight', + 'model.1.sub.21.RDB3.conv1.0.bias', 'model.1.sub.21.RDB3.conv2.0.weight', + 'model.1.sub.21.RDB3.conv2.0.bias', 'model.1.sub.21.RDB3.conv3.0.weight', + 'model.1.sub.21.RDB3.conv3.0.bias', 'model.1.sub.21.RDB3.conv4.0.weight', + 'model.1.sub.21.RDB3.conv4.0.bias', 'model.1.sub.21.RDB3.conv5.0.weight', + 'model.1.sub.21.RDB3.conv5.0.bias', 'model.1.sub.22.RDB1.conv1.0.weight', + 'model.1.sub.22.RDB1.conv1.0.bias', 'model.1.sub.22.RDB1.conv2.0.weight', + 'model.1.sub.22.RDB1.conv2.0.bias', 'model.1.sub.22.RDB1.conv3.0.weight', + 'model.1.sub.22.RDB1.conv3.0.bias', 'model.1.sub.22.RDB1.conv4.0.weight', + 'model.1.sub.22.RDB1.conv4.0.bias', 'model.1.sub.22.RDB1.conv5.0.weight', + 'model.1.sub.22.RDB1.conv5.0.bias', 'model.1.sub.22.RDB2.conv1.0.weight', + 'model.1.sub.22.RDB2.conv1.0.bias', 'model.1.sub.22.RDB2.conv2.0.weight', + 'model.1.sub.22.RDB2.conv2.0.bias', 'model.1.sub.22.RDB2.conv3.0.weight', + 'model.1.sub.22.RDB2.conv3.0.bias', 'model.1.sub.22.RDB2.conv4.0.weight', + 'model.1.sub.22.RDB2.conv4.0.bias', 'model.1.sub.22.RDB2.conv5.0.weight', + 'model.1.sub.22.RDB2.conv5.0.bias', 'model.1.sub.22.RDB3.conv1.0.weight', + 'model.1.sub.22.RDB3.conv1.0.bias', 'model.1.sub.22.RDB3.conv2.0.weight', + 'model.1.sub.22.RDB3.conv2.0.bias', 'model.1.sub.22.RDB3.conv3.0.weight', + 'model.1.sub.22.RDB3.conv3.0.bias', 'model.1.sub.22.RDB3.conv4.0.weight', + 'model.1.sub.22.RDB3.conv4.0.bias', 'model.1.sub.22.RDB3.conv5.0.weight', + 'model.1.sub.22.RDB3.conv5.0.bias', 'model.1.sub.23.weight', 'model.1.sub.23.bias', + 'model.3.weight', 'model.3.bias', 'model.6.weight', 'model.6.bias', 'model.8.weight', + 'model.8.bias', 'model.10.weight', 'model.10.bias'] + + +# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py +# Which enhanced stuff that was already here +class RRDBNet(nn.Module): + def __init__( + self, + state_dict, + norm=None, + act: str = "leakyrelu", + upsampler: str = "upconv", + mode: B.ConvMode = "CNA", + ) -> None: + """ + ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks. + By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao, + and Chen Change Loy. + This is old-arch Residual in Residual Dense Block Network and is not + the newest revision that's available at github.com/xinntao/ESRGAN. 
+ This is on purpose, the newest Network has severely limited the + potential use of the Network with no benefits. + This network supports model files from both new and old-arch. + Args: + norm: Normalization layer + act: Activation layer + upsampler: Upsample layer. upconv, pixel_shuffle + mode: Convolution mode + """ + super(RRDBNet, self).__init__() + self.model_arch = "ESRGAN" + self.sub_type = "SR" + + self.state = state_dict + self.norm = norm + self.act = act + self.upsampler = upsampler + self.mode = mode + + self.state_map = { + # currently supports old, new, and newer RRDBNet arch models + # ESRGAN, BSRGAN/RealSR, Real-ESRGAN + "model.0.weight": ("conv_first.weight",), + "model.0.bias": ("conv_first.bias",), + "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"), + "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"), + r"model.1.sub.\1.RDB\2.conv\3.0.\4": ( + r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)", + r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)", + ), + } + if "params_ema" in self.state: + self.state = self.state["params_ema"] + # self.model_arch = "RealESRGAN" + self.num_blocks = self.get_num_blocks() + self.plus = any("conv1x1" in k for k in self.state.keys()) + if self.plus: + self.model_arch = "ESRGAN+" + + self.state = self.new_to_old_arch(self.state) + + self.key_arr = list(self.state.keys()) + + self.in_nc: int = self.state[self.key_arr[0]].shape[1] + self.out_nc: int = self.state[self.key_arr[-1]].shape[0] + + self.scale: int = self.get_scale() + self.num_filters: int = self.state[self.key_arr[0]].shape[0] + + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + + self.supports_fp16 = True + self.supports_bfp16 = True + self.min_size_restriction = None + + # Detect if pixelunshuffle was used (Real-ESRGAN) + if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in ( + self.in_nc / 4, + self.in_nc / 16, + ): + self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc)) + else: + self.shuffle_factor = None + + upsample_block = { + "upconv": B.upconv_block, + "pixel_shuffle": B.pixelshuffle_block, + }.get(self.upsampler) + if upsample_block is None: + raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found") + + if self.scale == 3: + upsample_blocks = upsample_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + upscale_factor=3, + act_type=self.act, + c2x2=c2x2, + ) + else: + upsample_blocks = [ + upsample_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, + ) + for _ in range(int(math.log(self.scale, 2))) + ] + + self.model = B.sequential( + # fea conv + B.conv_block( + in_nc=self.in_nc, + out_nc=self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + c2x2=c2x2, + ), + B.ShortcutBlock( + B.sequential( + # rrdb blocks + *[ + B.RRDB( + nf=self.num_filters, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=self.norm, + act_type=self.act, + mode="CNA", + plus=self.plus, + c2x2=c2x2, + ) + for _ in range(self.num_blocks) + ], + # lr conv + B.conv_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + kernel_size=3, + norm_type=self.norm, + act_type=None, + mode=self.mode, + c2x2=c2x2, + ), + ) + ), + *upsample_blocks, + # hr_conv0 + B.conv_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + kernel_size=3, + norm_type=None, + act_type=self.act, + c2x2=c2x2, + ), + # hr_conv1 + 
B.conv_block( + in_nc=self.num_filters, + out_nc=self.out_nc, + kernel_size=3, + norm_type=None, + act_type=None, + c2x2=c2x2, + ), + ) + + # Adjust these properties for calculations outside of the model + if self.shuffle_factor: + self.in_nc //= self.shuffle_factor ** 2 + self.scale //= self.shuffle_factor + + self.load_state_dict(self.state, strict=False) + + def new_to_old_arch(self, state): + """Convert a new-arch model state dictionary to an old-arch dictionary.""" + if "params_ema" in state: + state = state["params_ema"] + + if "conv_first.weight" not in state: + # model is already old arch, this is a loose check, but should be sufficient + return state + + # add nb to state keys + for kind in ("weight", "bias"): + self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[ + f"model.1.sub./NB/.{kind}" + ] + del self.state_map[f"model.1.sub./NB/.{kind}"] + + old_state = OrderedDict() + for old_key, new_keys in self.state_map.items(): + for new_key in new_keys: + if r"\1" in old_key: + for k, v in state.items(): + sub = re.sub(new_key, old_key, k) + if sub != k: + old_state[sub] = v + else: + if new_key in state: + old_state[old_key] = state[new_key] + + # upconv layers + max_upconv = 0 + for key in state.keys(): + match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key) + if match is not None: + _, key_num, key_type = match.groups() + old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key] + max_upconv = max(max_upconv, int(key_num) * 3) + + # final layers + for key in state.keys(): + if key in ("HRconv.weight", "conv_hr.weight"): + old_state[f"model.{max_upconv + 2}.weight"] = state[key] + elif key in ("HRconv.bias", "conv_hr.bias"): + old_state[f"model.{max_upconv + 2}.bias"] = state[key] + elif key in ("conv_last.weight",): + old_state[f"model.{max_upconv + 4}.weight"] = state[key] + elif key in ("conv_last.bias",): + old_state[f"model.{max_upconv + 4}.bias"] = state[key] + + # Sort by first numeric value of each layer + def compare(item1, item2): + parts1 = item1.split(".") + parts2 = item2.split(".") + int1 = int(parts1[1]) + int2 = int(parts2[1]) + return int1 - int2 + + sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare)) + + # Rebuild the output dict in the right order + out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys) + + return out_dict + + def get_scale(self, min_part: int = 6) -> int: + n = 0 + for part in list(self.state): + parts = part.split(".")[1:] + if len(parts) == 2: + part_num = int(parts[0]) + if part_num > min_part and parts[1] == "weight": + n += 1 + return 2 ** n + + def get_num_blocks(self) -> int: + nbs = [] + state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + ( + r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)", + ) + for state_key in state_keys: + for k in self.state: + m = re.search(state_key, k) + if m: + nbs.append(int(m.group(1))) + if nbs: + break + return max(*nbs) + 1 + + def forward(self, x): + if self.shuffle_factor: + _, _, h, w = x.size() + mod_pad_h = ( + self.shuffle_factor - h % self.shuffle_factor + ) % self.shuffle_factor + mod_pad_w = ( + self.shuffle_factor - w % self.shuffle_factor + ) % self.shuffle_factor + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") + x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor) + x = self.model(x) + return x[:, :, : h * self.scale, : w * self.scale] + return self.model(x) diff --git a/toolkit/models/auraflow.py b/toolkit/models/auraflow.py new file mode 100644 index 
0000000000000000000000000000000000000000..e2539bda489ccc1975f42b9c9a027076f8fdfc74 --- /dev/null +++ b/toolkit/models/auraflow.py @@ -0,0 +1,127 @@ +import math +from functools import partial + +from torch import nn +import torch + + +class AuraFlowPatchEmbed(nn.Module): + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + pos_embed_max_size=None, + ): + super().__init__() + + self.num_patches = (height // patch_size) * (width // patch_size) + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) + self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1) + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + + def forward(self, latent): + batch_size, num_channels, height, width = latent.size() + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + try: + return latent + self.pos_embed + except RuntimeError: + raise RuntimeError( + f"Positional embeddings are too small for the number of patches. " + f"Please increase `pos_embed_max_size` to at least {self.num_patches}." + ) + + +# comfy +# def apply_pos_embeds(self, x, h, w): +# h = (h + 1) // self.patch_size +# w = (w + 1) // self.patch_size +# max_dim = max(h, w) +# +# cur_dim = self.h_max +# pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype) +# +# if max_dim > cur_dim: +# pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, +# -1) +# cur_dim = max_dim +# +# from_h = (cur_dim - h) // 2 +# from_w = (cur_dim - w) // 2 +# pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w] +# return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) + + # def patchify(self, x): + # B, C, H, W = x.size() + # pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + # pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + # + # x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') + # x = x.view( + # B, + # C, + # (H + 1) // self.patch_size, + # self.patch_size, + # (W + 1) // self.patch_size, + # self.patch_size, + # ) + # x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + # return x + +def patch_auraflow_pos_embed(pos_embed): + # we need to hijack the forward and replace with a custom one. Self is the model + def new_forward(self, latent): + batch_size, num_channels, height, width = latent.size() + + # add padding to the latent to make it match pos_embed + latent_size = height * width * num_channels / 16 # todo check where 16 comes from? 
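+        # Possible answer to the todo above (an educated guess, not verified): with AuraFlow's
+        # 4-channel latents and patch_size = 2, each patch covers 2 * 2 * 4 = 16 latent values,
+        # so dividing by 16 approximates the number of patches, which is what gets compared
+        # against pos_embed_size (self.pos_embed.shape[1]) just below.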
+ pos_embed_size = self.pos_embed.shape[1] + if latent_size < pos_embed_size: + total_padding = int(pos_embed_size - math.floor(latent_size)) + total_padding = total_padding // 16 + pad_height = total_padding // 2 + pad_width = total_padding - pad_height + # mirror padding on the right side + padding = (0, pad_width, 0, pad_height) + latent = torch.nn.functional.pad(latent, padding, mode='reflect') + elif latent_size > pos_embed_size: + amount_to_remove = latent_size - pos_embed_size + latent = latent[:, :, :-amount_to_remove] + + batch_size, num_channels, height, width = latent.size() + + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + try: + return latent + self.pos_embed + except RuntimeError: + raise RuntimeError( + f"Positional embeddings are too small for the number of patches. " + f"Please increase `pos_embed_max_size` to at least {self.num_patches}." + ) + + pos_embed.forward = partial(new_forward, pos_embed) diff --git a/toolkit/models/block.py b/toolkit/models/block.py new file mode 100644 index 0000000000000000000000000000000000000000..76356b5e3eb7c7d6dc4ed1629aac318c264111c5 --- /dev/null +++ b/toolkit/models/block.py @@ -0,0 +1,549 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from collections import OrderedDict + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import torch +import torch.nn as nn + + +#################### +# Basic blocks +#################### + + +def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1): + # helper selecting activation + # neg_slope: for leakyrelu and init of prelu + # n_prelu: for p_relu num_parameters + act_type = act_type.lower() + if act_type == "relu": + layer = nn.ReLU(inplace) + elif act_type == "leakyrelu": + layer = nn.LeakyReLU(neg_slope, inplace) + elif act_type == "prelu": + layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) + else: + raise NotImplementedError( + "activation layer [{:s}] is not found".format(act_type) + ) + return layer + + +def norm(norm_type: str, nc: int): + # helper selecting normalization layer + norm_type = norm_type.lower() + if norm_type == "batch": + layer = nn.BatchNorm2d(nc, affine=True) + elif norm_type == "instance": + layer = nn.InstanceNorm2d(nc, affine=False) + else: + raise NotImplementedError( + "normalization layer [{:s}] is not found".format(norm_type) + ) + return layer + + +def pad(pad_type: str, padding): + # helper selecting padding layer + # if padding is 'zero', do by conv layers + pad_type = pad_type.lower() + if padding == 0: + return None + if pad_type == "reflect": + layer = nn.ReflectionPad2d(padding) + elif pad_type == "replicate": + layer = nn.ReplicationPad2d(padding) + else: + raise NotImplementedError( + "padding layer [{:s}] is not implemented".format(pad_type) + ) + return layer + + +def get_valid_padding(kernel_size, dilation): + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + padding = (kernel_size - 1) // 2 + return padding + + +class ConcatBlock(nn.Module): + # Concat the output of a submodule to its input + def __init__(self, submodule): + super(ConcatBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = torch.cat((x, self.sub(x)), dim=1) + return output + + def __repr__(self): + tmpstr = "Identity .. 
\n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlock(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule): + super(ShortcutBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = x + self.sub(x) + return output + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlockSPSR(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule): + super(ShortcutBlockSPSR, self).__init__() + self.sub = submodule + + def forward(self, x): + return x, self.sub + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +def sequential(*args): + # Flatten Sequential. It unwraps nn.Sequential. + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError("sequential does not support OrderedDict input.") + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +ConvMode = Literal["CNA", "NAC", "CNAC"] + + +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + +def conv_block( + in_nc: int, + out_nc: int, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type: str | None = "relu", + mode: ConvMode = "CNA", + c2x2=False, +): + """ + Conv layer with padding, normalization, activation + mode: CNA --> Conv -> Norm -> Act + NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) + """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) + padding = get_valid_padding(kernel_size, dilation) + p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None + padding = padding if pad_type == "zero" else 0 + + c = nn.Conv2d( + in_nc, + out_nc, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=groups, + ) + a = act(act_type) if act_type else None + if mode in ("CNA", "CNAC"): + n = norm(norm_type, out_nc) if norm_type else None + return sequential(p, c, n, a) + elif mode == "NAC": + if norm_type is None and act_type is not None: + a = act(act_type, inplace=False) + # Important! 
+ # input----ReLU(inplace)----Conv--+----output + # |________________________| + # inplace ReLU will modify the input, therefore wrong output + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) + else: + assert False, f"Invalid conv mode {mode}" + + +#################### +# Useful blocks +#################### + + +class ResNetBlock(nn.Module): + """ + ResNet Block, 3-3 style + with extra residual scaling used in EDSR + (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17) + """ + + def __init__( + self, + in_nc, + mid_nc, + out_nc, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + bias=True, + pad_type="zero", + norm_type=None, + act_type="relu", + mode: ConvMode = "CNA", + res_scale=1, + ): + super(ResNetBlock, self).__init__() + conv0 = conv_block( + in_nc, + mid_nc, + kernel_size, + stride, + dilation, + groups, + bias, + pad_type, + norm_type, + act_type, + mode, + ) + if mode == "CNA": + act_type = None + if mode == "CNAC": # Residual path: |-CNAC-| + act_type = None + norm_type = None + conv1 = conv_block( + mid_nc, + out_nc, + kernel_size, + stride, + dilation, + groups, + bias, + pad_type, + norm_type, + act_type, + mode, + ) + # if in_nc != out_nc: + # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \ + # None, None) + # print('Need a projecter in ResNetBlock.') + # else: + # self.project = lambda x:x + self.res = sequential(conv0, conv1) + self.res_scale = res_scale + + def forward(self, x): + res = self.res(x).mul(self.res_scale) + return x + res + + +class RRDB(nn.Module): + """ + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + """ + + def __init__( + self, + nf, + kernel_size=3, + gc=32, + stride=1, + bias: bool = True, + pad_type="zero", + norm_type=None, + act_type="leakyrelu", + mode: ConvMode = "CNA", + _convtype="Conv2D", + _spectral_norm=False, + plus=False, + c2x2=False, + ): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + self.RDB2 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + self.RDB3 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class ResidualDenseBlock_5C(nn.Module): + """ + Residual Dense Block + style: 5 convs + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + Modified options that can be used: + - "Partial Convolution based Padding" arXiv:1811.11718 + - "Spectral normalization" arXiv:1802.05957 + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. + {Rakotonirina} and A. {Rasoanaivo} + + Args: + nf (int): Channel number of intermediate features (num_feat). + gc (int): Channels for each growth (num_grow_ch: growth channel, + i.e. intermediate channels). + convtype (str): the type of convolution to use. 
Default: 'Conv2D' + gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new + trainable parameters) + plus (bool): enable the additional residual paths from ESRGAN+ + (adds trainable parameters) + """ + + def __init__( + self, + nf=64, + kernel_size=3, + gc=32, + stride=1, + bias: bool = True, + pad_type="zero", + norm_type=None, + act_type="leakyrelu", + mode: ConvMode = "CNA", + plus=False, + c2x2=False, + ): + super(ResidualDenseBlock_5C, self).__init__() + + ## + + self.conv1x1 = conv1x1(nf, gc) if plus else None + ## + + + self.conv1 = conv_block( + nf, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv2 = conv_block( + nf + gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv3 = conv_block( + nf + 2 * gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv4 = conv_block( + nf + 3 * gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + if mode == "CNA": + last_act = None + else: + last_act = act_type + self.conv5 = conv_block( + nf + 4 * gc, + nf, + 3, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=last_act, + mode=mode, + c2x2=c2x2, + ) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + if self.conv1x1: + # pylint: disable=not-callable + x2 = x2 + self.conv1x1(x) # + + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + if self.conv1x1: + x4 = x4 + x2 # + + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +#################### +# Upsampler +#################### + + +def pixelshuffle_block( + in_nc: int, + out_nc: int, + upscale_factor=2, + kernel_size=3, + stride=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type="relu", +): + """ + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + """ + conv = conv_block( + in_nc, + out_nc * (upscale_factor ** 2), + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=None, + act_type=None, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + + +def upconv_block( + in_nc: int, + out_nc: int, + upscale_factor=2, + kernel_size=3, + stride=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type="relu", + mode="nearest", + c2x2=False, +): + # Up conv + # described in https://distill.pub/2016/deconv-checkerboard/ + # convert to float 16 if is bfloat16 + upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) + conv = conv_block( + in_nc, + out_nc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + c2x2=c2x2, + ) + return sequential(upsample, conv) diff --git a/toolkit/models/clip_fusion.py b/toolkit/models/clip_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f4346fd5ac3eae4c8d91e50df586acc8d4cd2fbe 
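# --- Editor's illustration (not part of the patch): an ESRGAN-style trunk built
# from the blocks defined above, ending in a 2x pixel-shuffle upsampler. The
# channel counts are assumptions chosen only to show the shape flow.
import torch
from toolkit.models.block import conv_block, RRDB, pixelshuffle_block, sequential

model = sequential(
    conv_block(3, 64, kernel_size=3, act_type=None),   # shallow feature extractor
    RRDB(nf=64, gc=32),                                 # residual-in-residual dense block
    pixelshuffle_block(64, 64, upscale_factor=2),       # conv to 256 channels, then PixelShuffle(2)
    conv_block(64, 3, kernel_size=3, act_type=None),    # back to RGB
)
lr = torch.randn(1, 3, 32, 32)
sr = model(lr)  # -> (1, 3, 64, 64)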
--- /dev/null +++ b/toolkit/models/clip_fusion.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn + +from toolkit.models.zipper_resampler import ContextualAlphaMask + + +# Conv1d MLP +# MLP that can alternately be used as a conv1d on dim 1 +class MLPC(nn.Module): + def __init__( + self, + in_dim, + out_dim, + hidden_dim, + do_conv=False, + use_residual=True + ): + super().__init__() + self.do_conv = do_conv + if use_residual: + assert in_dim == out_dim + # dont normalize if using conv + if not do_conv: + self.layernorm = nn.LayerNorm(in_dim) + + if do_conv: + self.fc1 = nn.Conv1d(in_dim, hidden_dim, 1) + self.fc2 = nn.Conv1d(hidden_dim, out_dim, 1) + else: + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + if not self.do_conv: + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class ZipperBlock(nn.Module): + def __init__( + self, + in_size, + in_tokens, + out_size, + out_tokens, + hidden_size, + hidden_tokens, + ): + super().__init__() + self.in_size = in_size + self.in_tokens = in_tokens + self.out_size = out_size + self.out_tokens = out_tokens + self.hidden_size = hidden_size + self.hidden_tokens = hidden_tokens + # permute to (batch_size, out_size, in_tokens) + + self.zip_token = MLPC( + in_dim=self.in_tokens, + out_dim=self.out_tokens, + hidden_dim=self.hidden_tokens, + do_conv=True, # no need to permute + use_residual=False + ) + + # permute to (batch_size, out_tokens, out_size) + + # in shpae: (batch_size, in_tokens, in_size) + self.zip_size = MLPC( + in_dim=self.in_size, + out_dim=self.out_size, + hidden_dim=self.hidden_size, + use_residual=False + ) + + def forward(self, x): + x = self.zip_token(x) + x = self.zip_size(x) + return x + + + + + + +# CLIPFusionModule +# Fuses any size of vision and text embeddings into a single embedding. +# remaps tokens and vectors. 
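# --- Editor's illustration (not part of the patch): ZipperBlock remaps both axes
# of a token sequence; zip_token is a Conv1d MLP over the token axis (dim 1) and
# zip_size a Linear MLP over the feature axis. The sizes below mirror the CLIP
# defaults used by CLIPFusionModule below and are otherwise assumptions.
import torch
from toolkit.models.clip_fusion import ZipperBlock

zipper = ZipperBlock(
    in_size=1024, in_tokens=257,    # e.g. CLIP vision hidden states
    out_size=768, out_tokens=77,    # e.g. CLIP text hidden states
    hidden_size=2048, hidden_tokens=514,
)
vision_embeds = torch.randn(2, 257, 1024)
remapped = zipper(vision_embeds)    # -> (2, 77, 768)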
+class CLIPFusionModule(nn.Module): + def __init__( + self, + text_hidden_size: int = 768, + text_tokens: int = 77, + vision_hidden_size: int = 1024, + vision_tokens: int = 257, + num_blocks: int = 1, + ): + super(CLIPFusionModule, self).__init__() + + self.text_hidden_size = text_hidden_size + self.text_tokens = text_tokens + self.vision_hidden_size = vision_hidden_size + self.vision_tokens = vision_tokens + + self.resampler = ZipperBlock( + in_size=self.vision_hidden_size, + in_tokens=self.vision_tokens, + out_size=self.text_hidden_size, + out_tokens=self.text_tokens, + hidden_size=self.vision_hidden_size * 2, + hidden_tokens=self.vision_tokens * 2 + ) + + self.zipper_blocks = torch.nn.ModuleList([ + ZipperBlock( + in_size=self.text_hidden_size * 2, + in_tokens=self.text_tokens, + out_size=self.text_hidden_size, + out_tokens=self.text_tokens, + hidden_size=self.text_hidden_size * 2, + hidden_tokens=self.text_tokens * 2 + ) for i in range(num_blocks) + ]) + + self.ctx_alpha = ContextualAlphaMask( + dim=self.text_hidden_size, + ) + + self.alpha = nn.Parameter(torch.zeros([text_tokens]) + 0.01) + + def forward(self, text_embeds, vision_embeds): + # text_embeds = (batch_size, 77, 768) + # vision_embeds = (batch_size, 257, 1024) + # output = (batch_size, 77, 768) + + vision_embeds = self.resampler(vision_embeds) + x = vision_embeds + for i, block in enumerate(self.zipper_blocks): + res = x + x = torch.cat([text_embeds, x], dim=-1) + x = block(x) + x = x + res + + # alpha mask + ctx_alpha = self.ctx_alpha(text_embeds) + # reshape alpha to (1, 77, 1) + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + + x = ctx_alpha * x * alpha + + x = x + text_embeds + + return x diff --git a/toolkit/models/clip_pre_processor.py b/toolkit/models/clip_pre_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..7956da0b4d0a5d6b882d4d19d9458bf409cc9b39 --- /dev/null +++ b/toolkit/models/clip_pre_processor.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn + + +class UpsampleBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_in = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), + nn.GELU() + ) + self.conv_up = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + nn.GELU() + ) + + self.conv_out = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + ) + + def forward(self, x): + x = self.conv_in(x) + x = self.conv_up(x) + x = self.conv_out(x) + return x + + +class CLIPImagePreProcessor(nn.Module): + def __init__( + self, + input_size=896, + clip_input_size=224, + downscale_factor: int = 16, + ): + super().__init__() + # make sure they are evenly divisible + assert input_size % clip_input_size == 0 + in_channels = 3 + + self.input_size = input_size + self.clip_input_size = clip_input_size + self.downscale_factor = downscale_factor + + subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 16 ** 2 = 768 + channels = subpixel_channels + + upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 16 / (896 / 224) = 4 + + num_upsample_blocks = int(upscale_factor // 2) # 4 // 2 = 2 + + # make the residual down up blocks + self.upsample_blocks = nn.ModuleList() + self.subpixel_blocks = nn.ModuleList() + current_channels = channels + current_downscale = downscale_factor + for _ in range(num_upsample_blocks): + # determine the reshuffled 
channel count for this dimension + output_downscale = current_downscale // 2 + out_channels = in_channels * output_downscale ** 2 + # out_channels = current_channels // 2 + self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels)) + current_channels = out_channels + current_downscale = output_downscale + self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale)) + + # (bs, 768, 56, 56) -> (bs, 192, 112, 112) + # (bs, 192, 112, 112) -> (bs, 48, 224, 224) + + self.conv_out = nn.Conv2d( + current_channels, + out_channels=3, + kernel_size=3, + padding=1 + ) # (bs, 48, 224, 224) -> (bs, 3, 224, 224) + + # do a pooling layer to downscale the input to 1/3 of the size + # (bs, 3, 896, 896) -> (bs, 3, 224, 224) + kernel_size = input_size // clip_input_size + self.res_down = nn.AvgPool2d( + kernel_size=kernel_size, + stride=kernel_size + ) # (bs, 3, 896, 896) -> (bs, 3, 224, 224) + + # make a blending for output residual with near 0 weight + self.res_blend = nn.Parameter(torch.tensor(0.001)) # (bs, 3, 224, 224) -> (bs, 3, 224, 224) + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 896, 896) -> (bs, 768, 56, 56) + + self.conv_in = nn.Sequential( + nn.Conv2d( + subpixel_channels, + channels, + kernel_size=3, + padding=1 + ), + nn.GELU() + ) # (bs, 768, 56, 56) -> (bs, 768, 56, 56) + + # make 2 deep blocks + + def forward(self, x): + inputs = x + # resize to input_size x input_size + x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic') + + res = self.res_down(inputs) + + x = self.unshuffle(x) + x = self.conv_in(x) + for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks): + x = up(x) + block_res = subpixel(inputs) + x = x + block_res + x = self.conv_out(x) + # blend residual + x = x * self.res_blend + res + return x diff --git a/toolkit/models/decorator.py b/toolkit/models/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..63f45aa9f944370727eed1a362c9bb04ad99fa9b --- /dev/null +++ b/toolkit/models/decorator.py @@ -0,0 +1,33 @@ +import torch + + +class Decorator(torch.nn.Module): + def __init__( + self, + num_tokens: int = 4, + token_size: int = 4096, + ) -> None: + super().__init__() + + self.weight: torch.nn.Parameter = torch.nn.Parameter( + torch.randn(num_tokens, token_size) + ) + # ensure it is float32 + self.weight.data = self.weight.data.float() + + def forward(self, text_embeds: torch.Tensor, is_unconditional=False) -> torch.Tensor: + # make sure the param is float32 + if self.weight.dtype != text_embeds.dtype: + self.weight.data = self.weight.data.float() + # expand batch to match text_embeds + batch_size = text_embeds.shape[0] + decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1) + if is_unconditional: + # zero pad the decorator embeds + decorator_embeds = torch.zeros_like(decorator_embeds) + + if decorator_embeds.dtype != text_embeds.dtype: + decorator_embeds = decorator_embeds.to(text_embeds.dtype) + text_embeds = torch.cat((text_embeds, decorator_embeds), dim=-2) + + return text_embeds diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py new file mode 100644 index 0000000000000000000000000000000000000000..48ce8786ca86c77c01d73fb6ff8875056ee9d4bc --- /dev/null +++ b/toolkit/models/flux.py @@ -0,0 +1,35 @@ + +# forward that bypasses the guidance embedding so it can be avoided during training. 
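# --- Editor's illustration (not part of the patch): the Decorator module above
# appends a small bank of learned tokens to the text embeddings, and zeroes them
# for the unconditional branch. The batch/sequence sizes are assumptions.
import torch
from toolkit.models.decorator import Decorator

deco = Decorator(num_tokens=4, token_size=4096)
text_embeds = torch.randn(2, 512, 4096)
cond = deco(text_embeds)                           # -> (2, 516, 4096)
uncond = deco(text_embeds, is_unconditional=True)  # same shape, appended tokens are zeros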
+from functools import partial + + +def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + pooled_projections = self.text_embedder(pooled_projection) + conditioning = timesteps_emb + pooled_projections + return conditioning + +# bypass the forward function + + +def bypass_flux_guidance(transformer): + if hasattr(transformer.time_text_embed, '_bfg_orig_forward'): + return + # dont bypass if it doesnt have the guidance embedding + if not hasattr(transformer.time_text_embed, 'guidance_embedder'): + return + transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward + transformer.time_text_embed.forward = partial( + guidance_embed_bypass_forward, transformer.time_text_embed + ) + +# restore the forward function + + +def restore_flux_guidance(transformer): + if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'): + return + transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward + del transformer.time_text_embed._bfg_orig_forward diff --git a/toolkit/models/flux_sage_attn.py b/toolkit/models/flux_sage_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..930a17000c92e7bef14d141c22e9acfef8f92bf4 --- /dev/null +++ b/toolkit/models/flux_sage_attn.py @@ -0,0 +1,94 @@ +from typing import Optional +from diffusers.models.attention_processor import Attention +import torch +import torch.nn.functional as F + + +class FluxSageAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + from sageattention import sageattn + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
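# --- Editor's illustration (not part of the patch): typical use of the guidance
# bypass above when training a guidance-distilled FLUX transformer. Both
# `transformer` and `training_step` are placeholders, not names from this patch.
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance

bypass_flux_guidance(transformer)   # swaps time_text_embed.forward for the bypass version
try:
    loss = training_step(transformer)
finally:
    restore_flux_guidance(transformer)  # always put the original forward back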
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = sageattn(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states \ No newline at end of file diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py new file mode 100644 index 0000000000000000000000000000000000000000..33613ed3193249ebefdea4fc3ff470bdd12a5a27 --- /dev/null +++ b/toolkit/models/ilora.py @@ -0,0 +1,364 @@ +import math +import weakref + +import torch +import torch.nn as nn +from typing import TYPE_CHECKING, List, Dict, Any +from toolkit.models.clip_fusion import ZipperBlock +from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler +import sys +from toolkit.paths import REPOS_ROOT +sys.path.append(REPOS_ROOT) +from ipadapter.ip_adapter.resampler import Resampler +from collections import OrderedDict + +if TYPE_CHECKING: + from toolkit.lora_special import LoRAModule + from toolkit.stable_diffusion_model import StableDiffusion + + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.dropout = nn.Dropout(dropout) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x = self.dropout(x) + if self.use_residual: + x = x + residual + return x + +class LoRAGenerator(torch.nn.Module): + def __init__( + self, + input_size: int = 768, # projection dimension + hidden_size: int = 768, + head_size: int = 512, 
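# --- Editor's illustration (not part of the patch): installing the SageAttention
# processor above on a FLUX transformer. Assumes the `sageattention` package is
# installed and that the model exposes the standard diffusers set_attn_processor
# API; `pipe` is a placeholder FluxPipeline.
from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0

pipe.transformer.set_attn_processor(FluxSageAttnProcessor2_0())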
+ num_heads: int = 1, + num_mlp_layers: int = 1, + output_size: int = 768, + dropout: float = 0.0 + ): + super().__init__() + self.input_size = input_size + self.num_heads = num_heads + self.simple = False + + self.output_size = output_size + + if self.simple: + self.head = nn.Linear(input_size, head_size, bias=False) + else: + self.lin_in = nn.Linear(input_size, hidden_size) + + self.mlp_blocks = nn.Sequential(*[ + MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers) + ]) + self.head = nn.Linear(hidden_size, head_size, bias=False) + self.norm = nn.LayerNorm(head_size) + + if num_heads == 1: + self.output = nn.Linear(head_size, self.output_size) + # for each output block. multiply weights by 0.01 + with torch.no_grad(): + self.output.weight.data *= 0.01 + else: + head_output_size = output_size // num_heads + self.outputs = nn.ModuleList([nn.Linear(head_size, head_output_size) for _ in range(num_heads)]) + # for each output block. multiply weights by 0.01 + with torch.no_grad(): + for output in self.outputs: + output.weight.data *= 0.01 + + # allow get device + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def forward(self, embedding): + if len(embedding.shape) == 2: + embedding = embedding.unsqueeze(1) + + x = embedding + + if not self.simple: + x = self.lin_in(embedding) + x = self.mlp_blocks(x) + x = self.head(x) + x = self.norm(x) + + if self.num_heads == 1: + x = self.output(x) + else: + out_chunks = torch.chunk(x, self.num_heads, dim=1) + x = [] + for out_layer, chunk in zip(self.outputs, out_chunks): + x.append(out_layer(chunk)) + x = torch.cat(x, dim=-1) + + return x.squeeze(1) + + +class InstantLoRAMidModule(torch.nn.Module): + def __init__( + self, + index: int, + lora_module: 'LoRAModule', + instant_lora_module: 'InstantLoRAModule', + up_shape: list = None, + down_shape: list = None, + ): + super(InstantLoRAMidModule, self).__init__() + self.up_shape = up_shape + self.down_shape = down_shape + self.index = index + self.lora_module_ref = weakref.ref(lora_module) + self.instant_lora_module_ref = weakref.ref(instant_lora_module) + + self.embed = None + + def down_forward(self, x, *args, **kwargs): + # get the embed + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + if x.dtype != self.embed.dtype: + x = x.to(self.embed.dtype) + down_size = math.prod(self.down_shape) + down_weight = self.embed[:, :down_size] + + batch_size = x.shape[0] + + # unconditional + if down_weight.shape[0] * 2 == batch_size: + down_weight = torch.cat([down_weight] * 2, dim=0) + + weight_chunks = torch.chunk(down_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + weight_chunk = weight_chunk.view(self.down_shape) + # check if is conv or linear + if len(weight_chunk.shape) == 4: + org_module = self.lora_module_ref().orig_module_ref() + stride = org_module.stride + padding = org_module.padding + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride) + else: + # run a simple linear layer with the down weight + x_chunk = x_chunk @ weight_chunk.T + x_out.append(x_chunk) + x = torch.cat(x_out, dim=0) + return x + + + def up_forward(self, x, *args, **kwargs): + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + if x.dtype != self.embed.dtype: + x = 
x.to(self.embed.dtype) + up_size = math.prod(self.up_shape) + up_weight = self.embed[:, -up_size:] + + batch_size = x.shape[0] + + # unconditional + if up_weight.shape[0] * 2 == batch_size: + up_weight = torch.cat([up_weight] * 2, dim=0) + + weight_chunks = torch.chunk(up_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + weight_chunk = weight_chunk.view(self.up_shape) + # check if is conv or linear + if len(weight_chunk.shape) == 4: + padding = 0 + if weight_chunk.shape[-1] == 3: + padding = 1 + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) + else: + # run a simple linear layer with the down weight + x_chunk = x_chunk @ weight_chunk.T + x_out.append(x_chunk) + x = torch.cat(x_out, dim=0) + return x + + + + +class InstantLoRAModule(torch.nn.Module): + def __init__( + self, + vision_hidden_size: int, + vision_tokens: int, + head_dim: int, + num_heads: int, # number of heads in the resampler + sd: 'StableDiffusion', + config=None + ): + super(InstantLoRAModule, self).__init__() + # self.linear = torch.nn.Linear(2, 1) + self.sd_ref = weakref.ref(sd) + self.dim = sd.network.lora_dim + self.vision_hidden_size = vision_hidden_size + self.vision_tokens = vision_tokens + self.head_dim = head_dim + self.num_heads = num_heads + + # stores the projection vector. Grabbed by modules + self.img_embeds: List[torch.Tensor] = None + + # disable merging in. It is slower on inference + self.sd_ref().network.can_merge_in = False + + self.ilora_modules = torch.nn.ModuleList() + + lora_modules = self.sd_ref().network.get_all_modules() + + output_size = 0 + + self.embed_lengths = [] + self.weight_mapping = [] + + for idx, lora_module in enumerate(lora_modules): + module_dict = lora_module.state_dict() + down_shape = list(module_dict['lora_down.weight'].shape) + up_shape = list(module_dict['lora_up.weight'].shape) + + self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]]) + + module_size = math.prod(down_shape) + math.prod(up_shape) + output_size += module_size + self.embed_lengths.append(module_size) + + + # add a new mid module that will take the original forward and add a vector to it + # this will be used to add the vector to the original forward + instant_module = InstantLoRAMidModule( + idx, + lora_module, + self, + up_shape=up_shape, + down_shape=down_shape + ) + + self.ilora_modules.append(instant_module) + + # replace the LoRA forwards + lora_module.lora_down.forward = instant_module.down_forward + lora_module.lora_up.forward = instant_module.up_forward + + + self.output_size = output_size + + number_formatted_output_size = "{:,}".format(output_size) + + print(f" ILORA output size: {number_formatted_output_size}") + + # if not evenly divisible, error + if self.output_size % self.num_heads != 0: + raise ValueError("Output size must be divisible by the number of heads") + + self.head_output_size = self.output_size // self.num_heads + + if vision_tokens > 1: + self.resampler = Resampler( + dim=vision_hidden_size, + depth=4, + dim_head=64, + heads=12, + num_queries=num_heads, # output tokens + embedding_dim=vision_hidden_size, + max_seq_len=vision_tokens, + output_dim=head_dim, + apply_pos_emb=True, # this is new + ff_mult=4 + ) + + self.proj_module = LoRAGenerator( + input_size=head_dim, + hidden_size=head_dim, + head_size=head_dim, + num_mlp_layers=1, + num_heads=self.num_heads, + output_size=self.output_size, + ) + + 
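# --- Editor's illustration (not part of the patch): how output_size is accumulated
# above. Each LoRA module contributes its full flattened down and up weights; for a
# hypothetical rank-16 Linear LoRA with in_features=320 and out_features=640:
#   down_shape = [16, 320]  -> 16 * 320  = 5,120
#   up_shape   = [640, 16]  -> 640 * 16  = 10,240
#   module_size = 5,120 + 10,240 = 15,360
# output_size is the sum over all modules, and forward() later slices the generated
# vector back into per-module chunks using embed_lengths.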
self.migrate_weight_mapping() + + def migrate_weight_mapping(self): + return + # # changes the names of the modules to common ones + # keymap = self.sd_ref().network.get_keymap() + # save_keymap = {} + # if keymap is not None: + # for ldm_key, diffusers_key in keymap.items(): + # # invert them + # save_keymap[diffusers_key] = ldm_key + # + # new_keymap = {} + # for key, value in self.weight_mapping: + # if key in save_keymap: + # new_keymap[save_keymap[key]] = value + # else: + # print(f"Key {key} not found in keymap") + # new_keymap[key] = value + # self.weight_mapping = new_keymap + # else: + # print("No keymap found. Using default names") + # return + + + def forward(self, img_embeds): + # expand token rank if only rank 2 + if len(img_embeds.shape) == 2: + img_embeds = img_embeds.unsqueeze(1) + + # resample the image embeddings + img_embeds = self.resampler(img_embeds) + img_embeds = self.proj_module(img_embeds) + if len(img_embeds.shape) == 3: + # merge the heads + img_embeds = img_embeds.mean(dim=1) + + self.img_embeds = [] + # get all the slices + start = 0 + for length in self.embed_lengths: + self.img_embeds.append(img_embeds[:, start:start+length]) + start += length + + + def get_additional_save_metadata(self) -> Dict[str, Any]: + # save the weight mapping + return { + "weight_mapping": self.weight_mapping, + "num_heads": self.num_heads, + "vision_hidden_size": self.vision_hidden_size, + "head_dim": self.head_dim, + "vision_tokens": self.vision_tokens, + "output_size": self.output_size, + } + diff --git a/toolkit/models/ilora2.py b/toolkit/models/ilora2.py new file mode 100644 index 0000000000000000000000000000000000000000..c46bd0a6d51a0e15856217a92c9aa27cf304287d --- /dev/null +++ b/toolkit/models/ilora2.py @@ -0,0 +1,419 @@ +import math +import weakref + +from toolkit.config_modules import AdapterConfig +import torch +import torch.nn as nn +from typing import TYPE_CHECKING, List, Dict, Any +from toolkit.models.clip_fusion import ZipperBlock +from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler +import sys +from toolkit.paths import REPOS_ROOT + +sys.path.append(REPOS_ROOT) +from ipadapter.ip_adapter.resampler import Resampler +from collections import OrderedDict + +if TYPE_CHECKING: + from toolkit.lora_special import LoRAModule + from toolkit.stable_diffusion_model import StableDiffusion + + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.dropout = nn.Dropout(dropout) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x = self.dropout(x) + if self.use_residual: + x = x + residual + return x + + +class LoRAGenerator(torch.nn.Module): + def __init__( + self, + input_size: int = 768, # projection dimension + hidden_size: int = 768, + head_size: int = 512, + num_heads: int = 1, + num_mlp_layers: int = 1, + output_size: int = 768, + dropout: float = 0.0 + ): + super().__init__() + self.input_size = input_size + self.num_heads = num_heads + self.simple = False + + self.output_size = output_size + + if self.simple: + self.head = nn.Linear(input_size, head_size, bias=False) + else: + self.lin_in = nn.Linear(input_size, hidden_size) + + self.mlp_blocks = nn.Sequential(*[ + 
MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in + range(num_mlp_layers) + ]) + self.head = nn.Linear(hidden_size, head_size, bias=False) + self.norm = nn.LayerNorm(head_size) + + if num_heads == 1: + self.output = nn.Linear(head_size, self.output_size) + # for each output block. multiply weights by 0.01 + with torch.no_grad(): + self.output.weight.data *= 0.01 + else: + head_output_size = output_size // num_heads + self.outputs = nn.ModuleList([nn.Linear(head_size, head_output_size) for _ in range(num_heads)]) + # for each output block. multiply weights by 0.01 + with torch.no_grad(): + for output in self.outputs: + output.weight.data *= 0.01 + + # allow get device + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def forward(self, embedding): + if len(embedding.shape) == 2: + embedding = embedding.unsqueeze(1) + + x = embedding + + if not self.simple: + x = self.lin_in(embedding) + x = self.mlp_blocks(x) + x = self.head(x) + x = self.norm(x) + + if self.num_heads == 1: + x = self.output(x) + else: + out_chunks = torch.chunk(x, self.num_heads, dim=1) + x = [] + for out_layer, chunk in zip(self.outputs, out_chunks): + x.append(out_layer(chunk)) + x = torch.cat(x, dim=-1) + + return x.squeeze(1) + + +class InstantLoRAMidModule(torch.nn.Module): + def __init__( + self, + index: int, + lora_module: 'LoRAModule', + instant_lora_module: 'InstantLoRAModule', + up_shape: list = None, + down_shape: list = None, + ): + super(InstantLoRAMidModule, self).__init__() + self.up_shape = up_shape + self.down_shape = down_shape + self.index = index + self.lora_module_ref = weakref.ref(lora_module) + self.instant_lora_module_ref = weakref.ref(instant_lora_module) + + self.do_up = instant_lora_module.config.ilora_up + self.do_down = instant_lora_module.config.ilora_down + self.do_mid = instant_lora_module.config.ilora_mid + + self.down_dim = self.down_shape[1] if self.do_down else 0 + self.mid_dim = self.up_shape[1] if self.do_mid else 0 + self.out_dim = self.up_shape[0] if self.do_up else 0 + + self.embed = None + + def down_forward(self, x, *args, **kwargs): + if not self.do_down: + return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs) + # get the embed + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + down_weight = self.embed[:, :self.down_dim] + + batch_size = x.shape[0] + + # unconditional + if down_weight.shape[0] * 2 == batch_size: + down_weight = torch.cat([down_weight] * 2, dim=0) + + try: + if len(x.shape) == 4: + # conv + down_weight = down_weight.view(batch_size, -1, 1, 1) + if x.shape[1] != down_weight.shape[1]: + raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}") + elif len(x.shape) == 2: + down_weight = down_weight.view(batch_size, -1) + if x.shape[1] != down_weight.shape[1]: + raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}") + else: + down_weight = down_weight.view(batch_size, 1, -1) + if x.shape[2] != down_weight.shape[2]: + raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}") + x = x * down_weight + x = self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs) + except Exception as e: + print(e) + raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}") + + return x + + def up_forward(self, x, *args, **kwargs): + # do mid here + x = self.mid_forward(x, *args, **kwargs) + if not 
self.do_up: + return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs) + # get the embed + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + up_weight = self.embed[:, -self.out_dim:] + + batch_size = x.shape[0] + + # unconditional + if up_weight.shape[0] * 2 == batch_size: + up_weight = torch.cat([up_weight] * 2, dim=0) + + try: + if len(x.shape) == 4: + # conv + up_weight = up_weight.view(batch_size, -1, 1, 1) + elif len(x.shape) == 2: + up_weight = up_weight.view(batch_size, -1) + else: + up_weight = up_weight.view(batch_size, 1, -1) + x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs) + x = x * up_weight + except Exception as e: + print(e) + raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}") + + return x + + def mid_forward(self, x, *args, **kwargs): + if not self.do_mid: + return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs) + batch_size = x.shape[0] + # get the embed + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + mid_weight = self.embed[:, self.down_dim:self.down_dim + self.mid_dim * self.mid_dim] + + # unconditional + if mid_weight.shape[0] * 2 == batch_size: + mid_weight = torch.cat([mid_weight] * 2, dim=0) + + weight_chunks = torch.chunk(mid_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + if len(x_chunk.shape) == 4: + # conv + weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim, 1, 1) + else: + weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim) + # check if is conv or linear + if len(weight_chunk.shape) == 4: + padding = 0 + if weight_chunk.shape[-1] == 3: + padding = 1 + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) + else: + # run a simple linear layer with the down weight + x_chunk = x_chunk @ weight_chunk.T + x_out.append(x_chunk) + x = torch.cat(x_out, dim=0) + return x + + +class InstantLoRAModule(torch.nn.Module): + def __init__( + self, + vision_hidden_size: int, + vision_tokens: int, + head_dim: int, + num_heads: int, # number of heads in the resampler + sd: 'StableDiffusion', + config: AdapterConfig + ): + super(InstantLoRAModule, self).__init__() + # self.linear = torch.nn.Linear(2, 1) + self.sd_ref = weakref.ref(sd) + self.dim = sd.network.lora_dim + self.vision_hidden_size = vision_hidden_size + self.vision_tokens = vision_tokens + self.head_dim = head_dim + self.num_heads = num_heads + + self.config: AdapterConfig = config + + # stores the projection vector. Grabbed by modules + self.img_embeds: List[torch.Tensor] = None + + # disable merging in. 
It is slower on inference + self.sd_ref().network.can_merge_in = False + + self.ilora_modules = torch.nn.ModuleList() + + lora_modules = self.sd_ref().network.get_all_modules() + + output_size = 0 + + self.embed_lengths = [] + self.weight_mapping = [] + + for idx, lora_module in enumerate(lora_modules): + module_dict = lora_module.state_dict() + down_shape = list(module_dict['lora_down.weight'].shape) + up_shape = list(module_dict['lora_up.weight'].shape) + + self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]]) + + # + # module_size = math.prod(down_shape) + math.prod(up_shape) + + # conv weight shape is (out_channels, in_channels, kernel_size, kernel_size) + # linear weight shape is (out_features, in_features) + + # just doing in dim and out dim + in_dim = down_shape[1] if self.config.ilora_down else 0 + mid_dim = down_shape[0] * down_shape[0] if self.config.ilora_mid else 0 + out_dim = up_shape[0] if self.config.ilora_up else 0 + module_size = in_dim + mid_dim + out_dim + + output_size += module_size + self.embed_lengths.append(module_size) + + # add a new mid module that will take the original forward and add a vector to it + # this will be used to add the vector to the original forward + instant_module = InstantLoRAMidModule( + idx, + lora_module, + self, + up_shape=up_shape, + down_shape=down_shape + ) + + self.ilora_modules.append(instant_module) + + # replace the LoRA forwards + lora_module.lora_down.orig_forward = lora_module.lora_down.forward + lora_module.lora_down.forward = instant_module.down_forward + lora_module.lora_up.orig_forward = lora_module.lora_up.forward + lora_module.lora_up.forward = instant_module.up_forward + + self.output_size = output_size + + number_formatted_output_size = "{:,}".format(output_size) + + print(f" ILORA output size: {number_formatted_output_size}") + + # if not evenly divisible, error + if self.output_size % self.num_heads != 0: + raise ValueError("Output size must be divisible by the number of heads") + + self.head_output_size = self.output_size // self.num_heads + + if vision_tokens > 1: + self.resampler = Resampler( + dim=vision_hidden_size, + depth=4, + dim_head=64, + heads=12, + num_queries=num_heads, # output tokens + embedding_dim=vision_hidden_size, + max_seq_len=vision_tokens, + output_dim=head_dim, + apply_pos_emb=True, # this is new + ff_mult=4 + ) + + self.proj_module = LoRAGenerator( + input_size=head_dim, + hidden_size=head_dim, + head_size=head_dim, + num_mlp_layers=1, + num_heads=self.num_heads, + output_size=self.output_size, + ) + + self.migrate_weight_mapping() + + def migrate_weight_mapping(self): + return + # # changes the names of the modules to common ones + # keymap = self.sd_ref().network.get_keymap() + # save_keymap = {} + # if keymap is not None: + # for ldm_key, diffusers_key in keymap.items(): + # # invert them + # save_keymap[diffusers_key] = ldm_key + # + # new_keymap = {} + # for key, value in self.weight_mapping: + # if key in save_keymap: + # new_keymap[save_keymap[key]] = value + # else: + # print(f"Key {key} not found in keymap") + # new_keymap[key] = value + # self.weight_mapping = new_keymap + # else: + # print("No keymap found. 
Using default names") + # return + + def forward(self, img_embeds): + # expand token rank if only rank 2 + if len(img_embeds.shape) == 2: + img_embeds = img_embeds.unsqueeze(1) + + # resample the image embeddings + img_embeds = self.resampler(img_embeds) + img_embeds = self.proj_module(img_embeds) + if len(img_embeds.shape) == 3: + # merge the heads + img_embeds = img_embeds.mean(dim=1) + + self.img_embeds = [] + # get all the slices + start = 0 + for length in self.embed_lengths: + self.img_embeds.append(img_embeds[:, start:start + length]) + start += length + + def get_additional_save_metadata(self) -> Dict[str, Any]: + # save the weight mapping + return { + "weight_mapping": self.weight_mapping, + "num_heads": self.num_heads, + "vision_hidden_size": self.vision_hidden_size, + "head_dim": self.head_dim, + "vision_tokens": self.vision_tokens, + "output_size": self.output_size, + "do_up": self.config.ilora_up, + "do_mid": self.config.ilora_mid, + "do_down": self.config.ilora_down, + } diff --git a/toolkit/models/pixtral_vision.py b/toolkit/models/pixtral_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..815f33101ffb4cfc672cefde6f94fd01ae198783 --- /dev/null +++ b/toolkit/models/pixtral_vision.py @@ -0,0 +1,618 @@ +import math +from typing import List, Optional, Tuple, Any, Union, TYPE_CHECKING +import os +import torch +import torch.nn as nn +from dataclasses import dataclass +from huggingface_hub import snapshot_download +from safetensors.torch import load_file +import json + +if TYPE_CHECKING: + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int, **kwargs): + super().__init__() + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # type: ignore + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]: + keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) + values = torch.repeat_interleave(values, repeats=repeats, dim=dim) + return keys, values + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = freqs_cis[:, None, :] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + n_heads: int, + head_dim: int, + n_kv_heads: int, + **kwargs, + ): + super().__init__() + + self.n_heads: int = n_heads + self.head_dim: int = head_dim + self.n_kv_heads: int = n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.head_dim ** -0.5 + + 
self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + cache: Optional[Any] = None, + mask: Optional['BlockDiagonalMask'] = None, + ) -> torch.Tensor: + from xformers.ops.fmha import memory_efficient_attention + assert mask is None or cache is None + seqlen_sum, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(seqlen_sum, self.n_heads, self.head_dim) + xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim) + xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + if cache is None: + key, val = xk, xv + elif cache.prefill: + key, val = cache.interleave_kv(xk, xv) + cache.update(xk, xv) + else: + cache.update(xk, xv) + key, val = cache.key, cache.value + key = key.view(seqlen_sum * cache.max_seq_len, + self.n_kv_heads, self.head_dim) + val = val.view(seqlen_sum * cache.max_seq_len, + self.n_kv_heads, self.head_dim) + + # Repeat keys and values to match number of query heads + key, val = repeat_kv(key, val, self.repeats, dim=1) + + # xformers requires (B=1, S, H, D) + xq, key, val = xq[None, ...], key[None, ...], val[None, ...] + output = memory_efficient_attention( + xq, key, val, mask if cache is None else cache.mask) + output = output.view(seqlen_sum, self.n_heads * self.head_dim) + + assert isinstance(output, torch.Tensor) + + return self.wo(output) # type: ignore + + +class TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + norm_eps: float, + **kwargs, + ): + super().__init__() + self.n_heads = n_heads + self.dim = dim + self.attention = Attention( + dim=dim, + n_heads=n_heads, + head_dim=head_dim, + n_kv_heads=n_kv_heads, + ) + self.attention_norm = RMSNorm(dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) + + self.feed_forward: nn.Module + self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + cache: Optional[Any] = None, + mask: Optional['BlockDiagonalMask'] = None, + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), freqs_cis, cache) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +@dataclass +class VisionEncoderArgs: + hidden_size: int + num_channels: int + image_size: int + patch_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + rope_theta: float = 1e4 # for rope-2D + image_token_id: int = 10 + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by + (height, width) position tuples + """ + # (dim / 2) frequency bases + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def position_meshgrid( + patch_embeds_list: 
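# --- Editor's illustration (not part of the patch): shape of the 2D rotary
# frequencies built above. Even frequency indices encode the row (height) position
# and odd ones the column (width) position; dim is the per-head dimension.
import torch
from toolkit.models.pixtral_vision import precompute_freqs_cis_2d

freqs = precompute_freqs_cis_2d(dim=64, height=4, width=4, theta=1e4)
assert freqs.shape == (4, 4, 32)          # (height, width, dim // 2)
assert freqs.dtype == torch.complex64     # unit-magnitude complex rotations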
list[torch.Tensor], +) -> torch.Tensor: + positions = torch.cat( + [ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + for p in patch_embeds_list + ] + ) + return positions + + +class PixtralVisionEncoder(nn.Module): + def __init__( + self, + hidden_size: int = 1024, + num_channels: int = 3, + image_size: int = 1024, + patch_size: int = 16, + intermediate_size: int = 4096, + num_hidden_layers: int = 24, + num_attention_heads: int = 16, + rope_theta: float = 1e4, # for rope-2D + image_token_id: int = 10, + **kwargs, + ): + super().__init__() + self.args = VisionEncoderArgs( + hidden_size=hidden_size, + num_channels=num_channels, + image_size=image_size, + patch_size=patch_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + rope_theta=rope_theta, + image_token_id=image_token_id, + ) + args = self.args + self.patch_conv = nn.Conv2d( + in_channels=args.num_channels, + out_channels=args.hidden_size, + kernel_size=args.patch_size, + stride=args.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) + self.transformer = VisionTransformerBlocks(args) + + head_dim = self.args.hidden_size // self.args.num_attention_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: Optional[torch.Tensor] = None + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder': + if os.path.isdir(pretrained_model_name_or_path): + model_folder = pretrained_model_name_or_path + else: + model_folder = snapshot_download(pretrained_model_name_or_path) + + # make sure there is a config + if not os.path.exists(os.path.join(model_folder, "config.json")): + raise ValueError(f"Could not find config.json in {model_folder}") + + # load config + with open(os.path.join(model_folder, "config.json"), "r") as f: + config = json.load(f) + + model = cls(**config) + + # see if there is a state_dict + if os.path.exists(os.path.join(model_folder, "model.safetensors")): + state_dict = load_file(os.path.join( + model_folder, "model.safetensors")) + model.load_state_dict(state_dict) + + return model + + @property + def max_patches_per_side(self) -> int: + return self.args.image_size // self.args.patch_size + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def freqs_cis(self) -> torch.Tensor: + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.args.hidden_size // self.args.num_attention_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.args.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + images: List[torch.Tensor], + ) -> torch.Tensor: + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + """ + Args: + images: list of N_img images of variable sizes, each of shape (C, H, W) + + Returns: + image_features: tensor of token features for all tokens of all images of + shape (N_toks, D) + """ + assert isinstance( + images, list), f"Expected list of images, got {type(images)}" + assert all(len(img.shape) == 3 for img in + images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}" + # pass images through initial convolution independently + patch_embeds_list = [self.patch_conv( + 
img.unsqueeze(0)).squeeze(0) for img in images] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(1).permute(1, 0) + for p in patch_embeds_list], dim=0) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + positions = position_meshgrid(patch_embeds_list).to(self.device) + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + + # pass through Transformer with a block diagonal mask delimiting images + mask = BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) + out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + + # remove batch dimension of the single sequence + return out # type: ignore[no-any-return] + + +class VisionLanguageAdapter(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.w_in = nn.Linear( + in_dim, + out_dim, + bias=True, + ) + self.gelu = nn.GELU() + self.w_out = nn.Linear(out_dim, out_dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # type: ignore[no-any-return] + return self.w_out(self.gelu(self.w_in(x))) + + +class VisionTransformerBlocks(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(args.num_hidden_layers): + self.layers.append( + TransformerBlock( + dim=args.hidden_size, + hidden_dim=args.intermediate_size, + n_heads=args.num_attention_heads, + n_kv_heads=args.num_attention_heads, + head_dim=args.hidden_size // args.num_attention_heads, + norm_eps=1e-5, + ) + ) + + def forward( + self, + x: torch.Tensor, + mask: 'BlockDiagonalMask', + freqs_cis: Optional[torch.Tensor], + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, mask=mask, freqs_cis=freqs_cis) + return x + + +DATASET_MEAN = [0.48145466, 0.4578275, 0.40821073] # RGB +DATASET_STD = [0.26862954, 0.26130258, 0.27577711] # RGB + + +def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + """ + Normalize a tensor image with mean and standard deviation. + + Args: + image (torch.Tensor): Image to be normalized, shape (C, H, W), values in [0, 1]. + mean (torch.Tensor): Mean for each channel. + std (torch.Tensor): Standard deviation for each channel. + + Returns: + torch.Tensor: Normalized image with shape (C, H, W). + """ + assert image.shape[0] == len(mean) == len( + std), f"{image.shape=}, {mean.shape=}, {std.shape=}" + + # Reshape mean and std to (C, 1, 1) for broadcasting + mean = mean.view(-1, 1, 1) + std = std.view(-1, 1, 1) + + return (image - mean) / std + + +def transform_image(image: torch.Tensor, new_size: tuple[int, int]) -> torch.Tensor: + """ + Resize and normalize the input image. + + Args: + image (torch.Tensor): Input image tensor of shape (C, H, W), values in [0, 1]. + new_size (tuple[int, int]): Target size (height, width) for resizing. + + Returns: + torch.Tensor: Resized and normalized image tensor of shape (C, new_H, new_W). 
+ """ + # Resize the image + resized_image = torch.nn.functional.interpolate( + image.unsqueeze(0), + size=new_size, + mode='bicubic', + align_corners=False + ).squeeze(0) + + # Normalize the image + normalized_image = normalize( + resized_image, + torch.tensor(DATASET_MEAN, device=image.device, dtype=image.dtype), + torch.tensor(DATASET_STD, device=image.device, dtype=image.dtype) + ) + + return normalized_image + + +class PixtralVisionImagePreprocessor: + def __init__(self, image_patch_size=16, max_image_size=1024) -> None: + self.image_patch_size = image_patch_size + self.max_image_size = max_image_size + self.image_token = 10 + + def _image_to_num_tokens(self, img: torch.Tensor, max_image_size = None) -> Tuple[int, int]: + w: Union[int, float] + h: Union[int, float] + + if max_image_size is None: + max_image_size = self.max_image_size + + w, h = img.shape[-1], img.shape[-2] + + # originally, pixtral used the largest of the 2 dimensions, but we + # will use the base size of the image based on number of pixels. + # ratio = max(h / self.max_image_size, w / self.max_image_size) # original + + base_size = int(math.sqrt(w * h)) + ratio = base_size / max_image_size + if ratio > 1: + w = round(w / ratio) + h = round(h / ratio) + + width_tokens = (w - 1) // self.image_patch_size + 1 + height_tokens = (h - 1) // self.image_patch_size + 1 + + return width_tokens, height_tokens + + def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor: + """ + Converts ImageChunks to numpy image arrays and image token ids + + Args: + image torch tensor with values 0-1 and shape of (C, H, W) + + Returns: + processed_image: tensor of token features for all tokens of all images of + """ + # should not have batch + if len(image.shape) == 4: + raise ValueError( + f"Expected image with shape (C, H, W), got {image.shape}") + + if image.min() < 0.0 or image.max() > 1.0: + raise ValueError( + f"image tensor values must be between 0 and 1. 
Got min: {image.min()}, max: {image.max()}") + + if max_image_size is None: + max_image_size = self.max_image_size + + w, h = self._image_to_num_tokens(image, max_image_size=max_image_size) + assert w > 0 + assert h > 0 + + new_image_size = ( + w * self.image_patch_size, + h * self.image_patch_size, + ) + + processed_image = transform_image(image, new_image_size) + + return processed_image + + +class PixtralVisionImagePreprocessorCompatibleReturn: + def __init__(self, pixel_values) -> None: + self.pixel_values = pixel_values + + +# Compatable version with ai toolkit flow +class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor): + def __init__(self, image_patch_size=16, max_image_size=1024) -> None: + super().__init__( + image_patch_size=image_patch_size, + max_image_size=max_image_size + ) + self.size = { + 'height': max_image_size, + 'width': max_image_size + } + self.max_image_size = max_image_size + self.image_mean = DATASET_MEAN + self.image_std = DATASET_STD + + def __call__( + self, + images, + return_tensors="pt", + do_resize=True, + do_rescale=False, + max_image_size=None, + ) -> torch.Tensor: + if max_image_size is None: + max_image_size = self.max_image_size + out_stack = [] + if len(images.shape) == 3: + images = images.unsqueeze(0) + for i in range(images.shape[0]): + image = images[i] + processed_image = super().__call__(image, max_image_size=max_image_size) + out_stack.append(processed_image) + + output = torch.stack(out_stack, dim=0) + return PixtralVisionImagePreprocessorCompatibleReturn(output) + + +class PixtralVisionEncoderCompatibleReturn: + def __init__(self, hidden_states) -> None: + self.hidden_states = hidden_states + + +class PixtralVisionEncoderCompatibleConfig: + def __init__(self): + self.image_size = 1024 + self.hidden_size = 1024 + self.patch_size = 16 + + +class PixtralVisionEncoderCompatible(PixtralVisionEncoder): + def __init__( + self, + hidden_size: int = 1024, + num_channels: int = 3, + image_size: int = 1024, + patch_size: int = 16, + intermediate_size: int = 4096, + num_hidden_layers: int = 24, + num_attention_heads: int = 16, + rope_theta: float = 1e4, # for rope-2D + image_token_id: int = 10, + **kwargs, + ): + super().__init__( + hidden_size=hidden_size, + num_channels=num_channels, + image_size=image_size, + patch_size=patch_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + rope_theta=rope_theta, + image_token_id=image_token_id, + ) + self.config = PixtralVisionEncoderCompatibleConfig() + + def forward( + self, + images, + output_hidden_states=True, + ) -> torch.Tensor: + out_stack = [] + if len(images.shape) == 3: + images = images.unsqueeze(0) + for i in range(images.shape[0]): + image = images[i] + # must be in an array + image_output = super().forward([image]) + out_stack.append(image_output) + + output = torch.stack(out_stack, dim=0) + return PixtralVisionEncoderCompatibleReturn([output]) diff --git a/toolkit/models/redux.py b/toolkit/models/redux.py new file mode 100644 index 0000000000000000000000000000000000000000..609ac50ae7f1404cfd85c63532339fcf94ae60c3 --- /dev/null +++ b/toolkit/models/redux.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + + +class ReduxImageEncoder(torch.nn.Module): + def __init__( + self, + redux_dim: int = 1152, + txt_in_features: int = 4096, + device=None, + dtype=None, + ) -> None: + super().__init__() + self.redux_dim = redux_dim + self.device = device + self.dtype = dtype + self.redux_up = 
nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype) + self.redux_down = nn.Linear( + txt_in_features * 3, txt_in_features, dtype=dtype) + + def forward(self, sigclip_embeds) -> torch.Tensor: + x = self.redux_up(sigclip_embeds) + x = torch.nn.functional.silu(x) + + projected_x = self.redux_down(x) + return projected_x diff --git a/toolkit/models/single_value_adapter.py b/toolkit/models/single_value_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..9284d02093c0be9fd16d5effb794a590d8449a4c --- /dev/null +++ b/toolkit/models/single_value_adapter.py @@ -0,0 +1,402 @@ +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import weakref +from typing import Union, TYPE_CHECKING + +from diffusers import Transformer2DModel +from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection +from toolkit.paths import REPOS_ROOT +sys.path.append(REPOS_ROOT) + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.custom_adapter import CustomAdapter + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = 
hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class SingleValueAdapterAttnProcessor(nn.Module): + r""" + Attention processor for Custom TE for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + adapter + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None, + adapter_hidden_size=None, has_bias=False, **kwargs): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.adapter_ref: weakref.ref = weakref.ref(adapter) + + self.hidden_size = hidden_size + self.adapter_hidden_size = adapter_hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + + @property + def is_active(self): + return self.adapter_ref().is_active + # return False + + @property + def unconditional_embeds(self): + return self.adapter_ref().adapter_ref().unconditional_embeds + + @property + def conditional_embeds(self): + return self.adapter_ref().adapter_ref().conditional_embeds + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + is_active = self.adapter_ref().is_active + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + # will be none if disabled + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, 
num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # only use one TE or the other. If our adapter is active only use ours + if self.is_active and self.conditional_embeds is not None: + + adapter_hidden_states = self.conditional_embeds + if adapter_hidden_states.shape[0] < batch_size: + # doing cfg + adapter_hidden_states = torch.cat([ + self.unconditional_embeds, + adapter_hidden_states + ], dim=0) + # needs to be shape (batch, 1, 1) + if len(adapter_hidden_states.shape) == 2: + adapter_hidden_states = adapter_hidden_states.unsqueeze(1) + # conditional_batch_size = adapter_hidden_states.shape[0] + # conditional_query = query + + # for ip-adapter + vd_key = self.to_k_adapter(adapter_hidden_states) + vd_value = self.to_v_adapter(adapter_hidden_states) + + vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + vd_hidden_states = F.scaled_dot_product_attention( + query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + vd_hidden_states = vd_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * vd_hidden_states + + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class SingleValueAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'StableDiffusion', + num_values: int = 1, + ): + super(SingleValueAdapter, self).__init__() + is_pixart = sd.is_pixart + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref: weakref.ref = weakref.ref(sd) + self.token_size = num_values + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + for name in attn_processor_keys: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + 
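# for transformer (PixArt-style) blocks the code assumes one attention hidden size, equal to cross_attention_dim, for every block +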
hidden_size = sd.unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor2_0() + else: + layer_name = name.split(".processor")[0] + to_k_adapter = unet_sd[layer_name + ".to_k.weight"] + to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + # if is_pixart: + # to_k_bias = unet_sd[layer_name + ".to_k.bias"] + # to_v_bias = unet_sd[layer_name + ".to_v.bias"] + # else: + # to_k_bias = None + # to_v_bias = None + + # add zero padding to the adapter + if to_k_adapter.shape[1] < self.token_size: + to_k_adapter = torch.cat([ + to_k_adapter, + torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + to_v_adapter = torch.cat([ + to_v_adapter, + torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + # if is_pixart: + # to_k_bias = torch.cat([ + # to_k_bias, + # torch.zeros(self.token_size - to_k_adapter.shape[1]).to( + # to_k_adapter.device, dtype=to_k_adapter.dtype) + # ], + # dim=0 + # ) + # to_v_bias = torch.cat([ + # to_v_bias, + # torch.zeros(self.token_size - to_v_adapter.shape[1]).to( + # to_k_adapter.device, dtype=to_k_adapter.dtype) + # ], + # dim=0 + # ) + elif to_k_adapter.shape[1] > self.token_size: + to_k_adapter = to_k_adapter[:, :self.token_size] + to_v_adapter = to_v_adapter[:, :self.token_size] + # if is_pixart: + # to_k_bias = to_k_bias[:self.token_size] + # to_v_bias = to_v_bias[:self.token_size] + else: + to_k_adapter = to_k_adapter + to_v_adapter = to_v_adapter + # if is_pixart: + # to_k_bias = to_k_bias + # to_v_bias = to_v_bias + + weights = { + "to_k_adapter.weight": to_k_adapter * 0.01, + "to_v_adapter.weight": to_v_adapter * 0.01, + } + # if is_pixart: + # weights["to_k_adapter.bias"] = to_k_bias + # weights["to_v_adapter.bias"] = to_v_bias + + attn_procs[name] = SingleValueAdapterAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + adapter=self, + adapter_hidden_size=self.token_size, + has_bias=False, + ) + attn_procs[name].load_state_dict(weights) + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList([ + transformer.transformer_blocks[i].attn1.processor for i in range(len(transformer.transformer_blocks)) + ] + [ + transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks)) + ]) + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + # make a getter to see if is active + @property + def is_active(self): + return self.adapter_ref().is_active + + def forward(self, input): + return input diff --git a/toolkit/models/size_agnostic_feature_encoder.py b/toolkit/models/size_agnostic_feature_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a716aec504503afa2c103506876c91bc2b617d07 --- /dev/null +++ b/toolkit/models/size_agnostic_feature_encoder.py @@ -0,0 +1,256 @@ +import os +from typing import Union, Optional + +import 
torch +import torch.nn as nn +from transformers.image_processing_utils import BaseImageProcessor + + +class SAFEReducerBlock(nn.Module): + """ + This is the block that reduces the size of an vactor w and h be half. It is designed to be iterative + So it is run multiple times to reduce an image to a desired dimension while carrying a shrinking residual + along for the ride. This is done to preserve information. + """ + def __init__(self, channels=512): + super(SAFEReducerBlock, self).__init__() + self.channels = channels + + activation = nn.GELU + + self.reducer = nn.Sequential( + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + activation(), + nn.BatchNorm2d(channels), + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + activation(), + nn.BatchNorm2d(channels), + nn.AvgPool2d(kernel_size=2, stride=2), + ) + self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2) + + def forward(self, x): + res = self.residual_shrink(x) + reduced = self.reducer(x) + return reduced + res + + +class SizeAgnosticFeatureEncoder(nn.Module): + def __init__( + self, + in_channels=3, + num_tokens=8, + num_vectors=768, + reducer_channels=512, + channels=2048, + downscale_factor: int = 8, + ): + super(SizeAgnosticFeatureEncoder, self).__init__() + self.num_tokens = num_tokens + self.num_vectors = num_vectors + self.channels = channels + self.reducer_channels = reducer_channels + self.gradient_checkpointing = False + + # input is minimum of (bs, 3, 256, 256) + + subpixel_channels = in_channels * downscale_factor ** 2 + + # PixelUnshuffle(8 = # (bs, 3, 32, 32) -> (bs, 192, 32, 32) + # PixelUnshuffle(16 = # (bs, 3, 16, 16) -> (bs, 48, 16, 16) + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 256, 256) -> (bs, 192, 32, 32) + + self.conv_in = nn.Conv2d(subpixel_channels, reducer_channels, kernel_size=3, padding=1) # (bs, 192, 32, 32) -> (bs, 512, 32, 32) + + # run as many times as needed to get to min feature of 8 on the smallest dimension + self.reducer = SAFEReducerBlock(reducer_channels) # (bs, 512, 32, 32) -> (bs, 512, 8, 8) + + self.reduced_out = nn.Conv2d( + reducer_channels, self.channels, kernel_size=3, padding=1 + ) # (bs, 512, 8, 8) -> (bs, 2048, 8, 8) + + # (bs, 2048, 8, 8) + self.block1 = SAFEReducerBlock(self.channels) # (bs, 2048, 8, 8) -> (bs, 2048, 4, 4) + self.block2 = SAFEReducerBlock(self.channels) # (bs, 2048, 8, 8) -> (bs, 2048, 2, 2) + + # reduce mean of dims 2 and 3 + self.adaptive_pool = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + ) + + # (bs, 2048) + # linear layer to (bs, self.num_vectors * self.num_tokens) + self.fc1 = nn.Linear(self.channels, self.num_vectors * self.num_tokens) + + # (bs, self.num_vectors * self.num_tokens) = (bs, 8 * 768) = (bs, 6144) + + def forward(self, x): + x = self.unshuffle(x) + x = self.conv_in(x) + + while True: + # reduce until we get as close to 8x8 as possible without going under + x = self.reducer(x) + if x.shape[2] // 2 < 8 or x.shape[3] // 2 < 8: + break + + x = self.reduced_out(x) + x = self.block1(x) + x = self.block2(x) + x = self.adaptive_pool(x) + x = self.fc1(x) + + # reshape + x = x.view(-1, self.num_tokens, self.num_vectors) + + return x + + +class SAFEIPReturn: + def __init__(self, pixel_values): + self.pixel_values = pixel_values + + +class SAFEImageProcessor(BaseImageProcessor): + def __init__( + self, + max_size=1024, + min_size=256, + **kwargs + ): + super().__init__(**kwargs) + self.max_size = max_size + self.min_size = min_size + + @classmethod + def from_pretrained( + cls, + 
pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ): + # not needed + return cls(**kwargs) + + def __call__( + self, + images, + **kwargs + ): + # TODO allow for random resizing + # comes in 0 - 1 range + # if any size is smaller than 256, resize to 256 + # if any size is larger than max_size, resize to max_size + if images.min() < -0.3 or images.max() > 1.3: + raise ValueError( + "images fed into SAFEImageProcessor values must be between 0 and 1. Got min: {}, max: {}".format( + images.min(), images.max() + )) + + # make sure we have (bs, 3, h, w) + while len(images.shape) < 4: + images = images.unsqueeze(0) + + # expand to 3 channels if we only have 1 channel + if images.shape[1] == 1: + images = torch.cat([images, images, images], dim=1) + + width = images.shape[3] + height = images.shape[2] + + if width < self.min_size or height < self.min_size: + # scale up so that the smallest size is 256 + if width < height: + new_width = self.min_size + new_height = int(height * (self.min_size / width)) + else: + new_height = self.min_size + new_width = int(width * (self.min_size / height)) + images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear', + align_corners=False) + + elif width > self.max_size or height > self.max_size: + # scale down so that the largest size is max_size but do not shrink the other size below 256 + if width > height: + new_width = self.max_size + new_height = int(height * (self.max_size / width)) + else: + new_height = self.max_size + new_width = int(width * (self.max_size / height)) + + if new_width < self.min_size: + new_width = self.min_size + new_height = int(height * (self.min_size / width)) + + if new_height < self.min_size: + new_height = self.min_size + new_width = int(width * (self.min_size / height)) + + images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear', + align_corners=False) + + # if wither side is not divisible by 16, mirror pad to make it so + if images.shape[2] % 16 != 0: + pad = 16 - (images.shape[2] % 16) + pad1 = pad // 2 + pad2 = pad - pad1 + images = nn.functional.pad(images, (0, 0, pad1, pad2), mode='reflect') + if images.shape[3] % 16 != 0: + pad = 16 - (images.shape[3] % 16) + pad1 = pad // 2 + pad2 = pad - pad1 + images = nn.functional.pad(images, (pad1, pad2, 0, 0), mode='reflect') + + return SAFEIPReturn(images) + + +class SAFEVMConfig: + def __init__( + self, + in_channels=3, + num_tokens=8, + num_vectors=768, + reducer_channels=512, + channels=2048, + downscale_factor: int = 8, + **kwargs + ): + self.in_channels = in_channels + self.num_tokens = num_tokens + self.num_vectors = num_vectors + self.reducer_channels = reducer_channels + self.channels = channels + self.downscale_factor = downscale_factor + self.image_size = 224 + + self.hidden_size = num_vectors + self.projection_dim = num_vectors + + +class SAFEVMReturn: + def __init__(self, output): + self.output = output + # todo actually do hidden states. 
This is just for code compatability for now + self.hidden_states = [output for _ in range(13)] + + +class SAFEVisionModel(SizeAgnosticFeatureEncoder): + def __init__(self, **kwargs): + self.config = SAFEVMConfig(**kwargs) + self.image_size = None + # super().__init__(**kwargs) + super(SAFEVisionModel, self).__init__(**kwargs) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + # not needed + return SAFEVisionModel(**kwargs) + + def forward(self, x, **kwargs): + return SAFEVMReturn(super().forward(x)) diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7679aac14f803ef58f6ad3ed078076234c5695 --- /dev/null +++ b/toolkit/models/te_adapter.py @@ -0,0 +1,460 @@ +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import weakref +from typing import Union, TYPE_CHECKING + + +from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection +from diffusers.models.embeddings import PixArtAlphaTextProjection + +from toolkit import train_tools +from toolkit.paths import REPOS_ROOT +from toolkit.prompt_utils import PromptEmbeds +from diffusers import Transformer2DModel + +sys.path.append(REPOS_ROOT) + +from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0 + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline + from toolkit.custom_adapter import CustomAdapter + + +class TEAdapterCaptionProjection(nn.Module): + def __init__(self, caption_channels, adapter: 'TEAdapter'): + super().__init__() + in_features = caption_channels + self.adapter_ref: weakref.ref = weakref.ref(adapter) + sd = adapter.sd_ref() + self.parent_module_ref = weakref.ref(sd.unet.caption_projection) + parent_module = self.parent_module_ref() + self.linear_1 = nn.Linear( + in_features=in_features, + out_features=parent_module.linear_1.out_features, + bias=True + ) + self.linear_2 = nn.Linear( + in_features=parent_module.linear_2.in_features, + out_features=parent_module.linear_2.out_features, + bias=True + ) + + # save the orig forward + parent_module.linear_1.orig_forward = parent_module.linear_1.forward + parent_module.linear_2.orig_forward = parent_module.linear_2.forward + + # replace original forward + parent_module.orig_forward = parent_module.forward + parent_module.forward = self.forward + + + @property + def is_active(self): + return self.adapter_ref().is_active + + @property + def unconditional_embeds(self): + return self.adapter_ref().adapter_ref().unconditional_embeds + + @property + def conditional_embeds(self): + return self.adapter_ref().adapter_ref().conditional_embeds + + def forward(self, caption): + if self.is_active and self.conditional_embeds is not None: + adapter_hidden_states = self.conditional_embeds.text_embeds + # check if we are doing unconditional + if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != caption.shape[0]: + # concat unconditional to match the hidden state batch size + if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1: + unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0) + else: + unconditional = self.unconditional_embeds.text_embeds + adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0) + hidden_states = self.linear_1(adapter_hidden_states) + hidden_states = 
self.parent_module_ref().act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + else: + return self.parent_module_ref().orig_forward(caption) + + +class TEAdapterAttnProcessor(nn.Module): + r""" + Attention processor for Custom TE for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + adapter + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None, + adapter_hidden_size=None, layer_name=None): + super().__init__() + self.layer_name = layer_name + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.adapter_ref: weakref.ref = weakref.ref(adapter) + + self.hidden_size = hidden_size + self.adapter_hidden_size = adapter_hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False) + self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False) + + @property + def is_active(self): + return self.adapter_ref().is_active + + @property + def unconditional_embeds(self): + return self.adapter_ref().adapter_ref().unconditional_embeds + + @property + def conditional_embeds(self): + return self.adapter_ref().adapter_ref().conditional_embeds + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + is_active = self.adapter_ref().is_active + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + # will be none if disabled + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # only use one TE or the other. 
If our adapter is active only use ours + if self.is_active and self.conditional_embeds is not None: + adapter_hidden_states = self.conditional_embeds.text_embeds + # check if we are doing unconditional + if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != encoder_hidden_states.shape[0]: + # concat unconditional to match the hidden state batch size + if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1: + unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0) + else: + unconditional = self.unconditional_embeds.text_embeds + adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0) + # for ip-adapter + key = self.to_k_adapter(adapter_hidden_states) + value = self.to_v_adapter(adapter_hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + try: + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + except RuntimeError: + raise RuntimeError(f"key shape: {key.shape}, value shape: {value.shape}") + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + # remove attn mask if doing clip + if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip": + attention_mask = None + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class TEAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'StableDiffusion', + te: Union[T5EncoderModel], + tokenizer: CLIPTokenizer + ): + super(TEAdapter, self).__init__() + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref: weakref.ref = weakref.ref(sd) + self.te_ref: weakref.ref = weakref.ref(te) + self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer) + self.adapter_modules = [] + self.caption_projection = None + self.embeds_store = [] + is_pixart = sd.is_pixart + + if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5": + self.token_size = self.te_ref().config.d_model + else: + self.token_size = self.te_ref().config.hidden_size + + # add text projection if is sdxl + self.text_projection = None + if sd.is_xl: + clip_with_projection: CLIPTextModelWithProjection = sd.text_encoder[0] + self.text_projection = nn.Linear(te.config.hidden_size, clip_with_projection.config.projection_dim, bias=False) + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + attn_dict_map = { + + } + module_idx = 0 + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + attn_processor_keys = [] + if 
is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + attn_processor_names = [] + + blocks = [] + transformer_blocks = [] + for name in attn_processor_keys: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \ + sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + hidden_size = sd.unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor2_0() + else: + layer_name = name.split(".processor")[0] + to_k_adapter = unet_sd[layer_name + ".to_k.weight"] + to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + + # add zero padding to the adapter + if to_k_adapter.shape[1] < self.token_size: + to_k_adapter = torch.cat([ + to_k_adapter, + torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + to_v_adapter = torch.cat([ + to_v_adapter, + torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + elif to_k_adapter.shape[1] > self.token_size: + to_k_adapter = to_k_adapter[:, :self.token_size] + to_v_adapter = to_v_adapter[:, :self.token_size] + else: + to_k_adapter = to_k_adapter + to_v_adapter = to_v_adapter + + # todo resize to the TE hidden size + weights = { + "to_k_adapter.weight": to_k_adapter, + "to_v_adapter.weight": to_v_adapter, + } + + if self.sd_ref().is_pixart: + # pixart is much more sensitive + weights = { + "to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01, + "to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01, + } + + attn_procs[name] = TEAdapterAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.adapter_ref().config.num_tokens, + adapter=self, + adapter_hidden_size=self.token_size, + layer_name=layer_name + ) + attn_procs[name].load_state_dict(weights) + self.adapter_modules.append(attn_procs[name]) + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn2.processor for i in + range(len(transformer.transformer_blocks)) + ]) + self.caption_projection = TEAdapterCaptionProjection( + caption_channels=self.token_size, + adapter=self, + ) + + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = 
torch.nn.ModuleList(sd.unet.attn_processors.values()) + + # make a getter to see if is active + @property + def is_active(self): + return self.adapter_ref().is_active + + def encode_text(self, text): + te: T5EncoderModel = self.te_ref() + tokenizer: T5Tokenizer = self.tokenizer_ref() + attn_mask_float = None + + # input_ids = tokenizer( + # text, + # max_length=77, + # padding="max_length", + # truncation=True, + # return_tensors="pt", + # ).input_ids.to(te.device) + # outputs = te(input_ids=input_ids) + # outputs = outputs.last_hidden_state + if self.adapter_ref().config.text_encoder_arch == "clip": + embeds = train_tools.encode_prompts( + tokenizer, + te, + text, + truncate=True, + max_length=self.adapter_ref().config.num_tokens, + ) + attention_mask = torch.ones(embeds.shape[:2], device=embeds.device) + + elif self.adapter_ref().config.text_encoder_arch == "pile-t5": + # just use aura pile + embeds, attention_mask = train_tools.encode_prompts_auraflow( + tokenizer, + te, + text, + truncate=True, + max_length=self.adapter_ref().config.num_tokens, + ) + + else: + embeds, attention_mask = train_tools.encode_prompts_pixart( + tokenizer, + te, + text, + truncate=True, + max_length=self.adapter_ref().config.num_tokens, + ) + if attention_mask is not None: + attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype) + if self.text_projection is not None: + # pool the output of embeds ignoring 0 in the attention mask + if attn_mask_float is not None: + pooled_output = embeds * attn_mask_float.unsqueeze(-1) + else: + pooled_output = embeds + + # reduce along dim 1 while maintaining batch and dim 2 + pooled_output_sum = pooled_output.sum(dim=1) + + if attn_mask_float is not None: + attn_mask_sum = attn_mask_float.sum(dim=1).unsqueeze(-1) + + pooled_output = pooled_output_sum / attn_mask_sum + + pooled_embeds = self.text_projection(pooled_output) + + prompt_embeds = PromptEmbeds( + (embeds, pooled_embeds), + attention_mask=attention_mask, + ).detach() + + else: + + prompt_embeds = PromptEmbeds( + embeds, + attention_mask=attention_mask, + ).detach() + + return prompt_embeds + + + + def forward(self, input): + return input diff --git a/toolkit/models/te_aug_adapter.py b/toolkit/models/te_aug_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..02cbbec1a6eb4fdcce7d94067a976fbde496f89c --- /dev/null +++ b/toolkit/models/te_aug_adapter.py @@ -0,0 +1,253 @@ +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import weakref +from typing import Union, TYPE_CHECKING, Optional, Tuple + +from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer +from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention + +from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule +from toolkit.paths import REPOS_ROOT +from toolkit.resampler import Resampler + +sys.path.append(REPOS_ROOT) + +from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0 + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.custom_adapter import CustomAdapter + + +class TEAugAdapterCLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter'): + super().__init__() + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.attn_module_ref: weakref.ref = weakref.ref(attn_module) + self.k_proj_adapter = nn.Linear(attn_module.embed_dim, 
attn_module.embed_dim) + self.v_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim) + # copy the weights from the original module + self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01 + self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01 + #reset the bias + self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001 + self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001 + + self.zipper = ZipperModule( + in_size=attn_module.embed_dim, + in_tokens=77 * 2, + out_size=attn_module.embed_dim, + out_tokens=77, + hidden_size=attn_module.embed_dim, + hidden_tokens=77, + ) + # self.k_proj_adapter.weight.data = torch.zeros_like(attn_module.k_proj.weight.data) + # self.v_proj_adapter.weight.data = torch.zeros_like(attn_module.v_proj.weight.data) + # #reset the bias + # self.k_proj_adapter.bias.data = torch.zeros_like(attn_module.k_proj.bias.data) + # self.v_proj_adapter.bias.data = torch.zeros_like(attn_module.v_proj.bias.data) + + # replace the original forward with our forward + self.original_forward = attn_module.forward + attn_module.forward = self.forward + + + @property + def is_active(self): + return self.adapter_ref().is_active + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + attn_module = self.attn_module_ref() + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = attn_module.q_proj(hidden_states) * attn_module.scale + key_states = attn_module._shape(attn_module.k_proj(hidden_states), -1, bsz) + value_states = attn_module._shape(attn_module.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * attn_module.num_heads, -1, attn_module.head_dim) + query_states = attn_module._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * attn_module.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * attn_module.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len) + + 
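# standard CLIP self-attention over the text hidden states continues below; the image-conditioned adapter branch (k_proj_adapter / v_proj_adapter and the zipper module) is computed afterwards and added to the attention output +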
attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * attn_module.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref() + if self.adapter_ref().is_active and adapter.conditional_embeds is not None: + # apply the adapter + + if adapter.is_unconditional_run: + embeds = adapter.unconditional_embeds + else: + embeds = adapter.conditional_embeds + # if the shape is not the same on batch, we are doing cfg and need to concat unconditional as well + if embeds.size(0) != bsz: + embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0) + + key_states_raw = self.k_proj_adapter(embeds) + key_states = attn_module._shape(key_states_raw, -1, bsz) + value_states_raw = self.v_proj_adapter(embeds) + value_states = attn_module._shape(value_states_raw, -1, bsz) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training) + attn_output_adapter = torch.bmm(attn_probs, value_states) + + if attn_output_adapter.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim): + raise ValueError( + f"`attn_output_adapter` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is" + f" {attn_output_adapter.size()}" + ) + + attn_output_adapter = attn_output_adapter.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim) + attn_output_adapter = attn_output_adapter.transpose(1, 2) + attn_output_adapter = attn_output_adapter.reshape(bsz, tgt_len, embed_dim) + + attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1)) + + # attn_output_adapter = attn_module.out_proj(attn_output_adapter) + attn_output = attn_output + attn_output_adapter + + attn_output = attn_module.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + +class TEAugAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'StableDiffusion', + ): + super(TEAugAdapter, self).__init__() + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref: weakref.ref = weakref.ref(sd) + + if isinstance(sd.text_encoder, list): + raise ValueError("Dual text encoders is not yet supported") + + # dim will come from text encoder + # dim = sd.unet.config['cross_attention_dim'] + text_encoder: CLIPTextModel = sd.text_encoder + dim = 
text_encoder.config.hidden_size + + clip_encoder: CLIPEncoder = text_encoder.text_model.encoder + # dim = clip_encoder.layers[-1].self_attn + + if hasattr(adapter.vision_encoder.config, 'hidden_sizes'): + embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1] + else: + embedding_dim = adapter.vision_encoder.config.hidden_size + + image_encoder_state_dict = adapter.vision_encoder.state_dict() + # max_seq_len = CLIP tokens + CLS token + in_tokens = 257 + if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: + # clip + in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + if adapter.config.image_encoder_arch.startswith('convnext'): + in_tokens = 16 * 16 + embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1] + + out_tokens = adapter.config.num_tokens if adapter.config.num_tokens > 0 else in_tokens + self.image_proj_model = ZipperModule( + in_size=embedding_dim, + in_tokens=in_tokens, + out_size=dim, + out_tokens=out_tokens, + hidden_size=dim, + hidden_tokens=out_tokens, + ) + # init adapter modules + attn_procs = {} + for idx, layer in enumerate(clip_encoder.layers): + name = f"clip_attention.{idx}" + attn_procs[name] = TEAugAdapterCLIPAttention( + layer.self_attn, + self + ) + + self.adapter_modules = torch.nn.ModuleList(list(attn_procs.values())) + + # make a getter to see if is active + @property + def is_active(self): + return self.adapter_ref().is_active + + + def forward(self, input): + # # apply the adapter + input = self.image_proj_model(input) + # self.embeds = input + return input diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3f9bc757143570067c2387ec8a6a96c909690d --- /dev/null +++ b/toolkit/models/vd_adapter.py @@ -0,0 +1,812 @@ +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import weakref +from typing import Union, TYPE_CHECKING, Optional +from collections import OrderedDict + +from diffusers import Transformer2DModel, FluxTransformer2DModel +from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection +from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter +from transformers import SiglipImageProcessor, SiglipVisionModel + +from toolkit.config_modules import AdapterConfig +from toolkit.paths import REPOS_ROOT +sys.path.append(REPOS_ROOT) + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.custom_adapter import CustomAdapter + + +# matches distribution of randn +class Norm(nn.Module): + def __init__(self, target_mean=0.0, target_std=1.0, eps=1e-6): + super(Norm, self).__init__() + self.target_mean = target_mean + self.target_std = target_std + self.eps = eps + + def forward(self, x): + dims = tuple(range(1, x.dim())) + mean = x.mean(dim=dims, keepdim=True) + std = x.std(dim=dims, keepdim=True) + + # Normalize + return self.target_std * (x - mean) / (std + self.eps) + self.target_mean + + +norm_layer = Norm() + +class SparseAutoencoder(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim): + super(SparseAutoencoder, self).__init__() + self.encoder = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, output_dim), + ) + self.norm = Norm() + self.decoder = nn.Sequential( + nn.Linear(output_dim, hidden_dim), + nn.GELU(), + 
nn.Linear(hidden_dim, input_dim), + ) + self.last_run = None + + def forward(self, x): + self.last_run = { + "input": x + } + x = self.encoder(x) + x = self.norm(x) + self.last_run["sparse"] = x + x = self.decoder(x) + x = self.norm(x) + self.last_run["output"] = x + return x + + +class MLPR(nn.Module): # MLP with reshaping + def __init__( + self, + in_dim, + in_channels, + out_dim, + out_channels, + use_residual=True + ): + super().__init__() + if use_residual: + assert in_dim == out_dim + # dont normalize if using conv + self.layer_norm = nn.LayerNorm(in_dim) + + self.fc1 = nn.Linear(in_dim, out_dim) + self.act_fn = nn.GELU() + self.conv1 = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x): + residual = x + x = self.layer_norm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.conv1(x) + return x + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if 
attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class VisionDirectAdapterAttnProcessor(nn.Module): + r""" + Attention processor for Custom TE for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + adapter + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None, + adapter_hidden_size=None, has_bias=False, **kwargs): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.adapter_ref: weakref.ref = weakref.ref(adapter) + + self.hidden_size = hidden_size + self.adapter_hidden_size = adapter_hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + + @property + def is_active(self): + return self.adapter_ref().is_active + # return False + + @property + def unconditional_embeds(self): + return self.adapter_ref().adapter_ref().unconditional_embeds + + @property + def conditional_embeds(self): + return self.adapter_ref().adapter_ref().conditional_embeds + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + is_active = self.adapter_ref().is_active + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + # will be none if disabled + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 
2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # only use one TE or the other. If our adapter is active only use ours + if self.is_active and self.conditional_embeds is not None: + + adapter_hidden_states = self.conditional_embeds + if adapter_hidden_states.shape[0] < batch_size: + adapter_hidden_states = torch.cat([ + self.unconditional_embeds, + adapter_hidden_states + ], dim=0) + # if it is image embeds, we need to add a 1 dim at inx 1 + if len(adapter_hidden_states.shape) == 2: + adapter_hidden_states = adapter_hidden_states.unsqueeze(1) + # conditional_batch_size = adapter_hidden_states.shape[0] + # conditional_query = query + + # for ip-adapter + vd_key = self.to_k_adapter(adapter_hidden_states) + vd_value = self.to_v_adapter(adapter_hidden_states) + + vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + vd_hidden_states = F.scaled_dot_product_attention( + query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + vd_hidden_states = vd_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * vd_hidden_states + + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CustomFluxVDAttnProcessor2_0(torch.nn.Module): + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None, + adapter_hidden_size=None, has_bias=False, block_idx=0, **kwargs): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.adapter_ref: weakref.ref = weakref.ref(adapter) + + self.hidden_size = hidden_size + self.adapter_hidden_size = adapter_hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.block_idx = block_idx + + self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + + @property + def is_active(self): + return self.adapter_ref().is_active + # return False + + @property + def unconditional_embeds(self): + return self.adapter_ref().adapter_ref().unconditional_embeds + + @property + def conditional_embeds(self): + return self.adapter_ref().adapter_ref().conditional_embeds + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. 
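+        # (image/latent-stream Q/K/V; when encoder_hidden_states is given, the text-stream
+        #  projections below are concatenated along the sequence dim before attention)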
+ query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # begin ip adapter + if self.is_active and self.conditional_embeds is not None: + adapter_hidden_states = self.conditional_embeds + block_scaler = self.adapter_ref().block_scaler + if block_scaler is not None: + # add 1 to block scaler so we can decay its weight to 1.0 + block_scaler = block_scaler[self.block_idx] + 1.0 + + if adapter_hidden_states.shape[0] < batch_size: + adapter_hidden_states = torch.cat([ + self.unconditional_embeds, + adapter_hidden_states + ], dim=0) + # if it is image embeds, we need to add a 1 dim at inx 1 + if len(adapter_hidden_states.shape) == 2: + adapter_hidden_states = adapter_hidden_states.unsqueeze(1) + # conditional_batch_size = adapter_hidden_states.shape[0] + # conditional_query = query + + # for ip-adapter + vd_key = self.to_k_adapter(adapter_hidden_states) + vd_value = self.to_v_adapter(adapter_hidden_states) + + vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + vd_hidden_states = F.scaled_dot_product_attention( + query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + vd_hidden_states = vd_hidden_states.to(query.dtype) + + # scale to block scaler + if 
block_scaler is not None: + orig_dtype = vd_hidden_states.dtype + if block_scaler.dtype != vd_hidden_states.dtype: + vd_hidden_states = vd_hidden_states.to(block_scaler.dtype) + vd_hidden_states = vd_hidden_states * block_scaler + if block_scaler.dtype != orig_dtype: + vd_hidden_states = vd_hidden_states.to(orig_dtype) + + hidden_states = hidden_states + self.scale * vd_hidden_states + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + +class VisionDirectAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'StableDiffusion', + vision_model: Union[CLIPVisionModelWithProjection], + ): + super(VisionDirectAdapter, self).__init__() + is_pixart = sd.is_pixart + is_flux = sd.is_flux + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref: weakref.ref = weakref.ref(sd) + self.config: AdapterConfig = adapter.config + self.vision_model_ref: weakref.ref = weakref.ref(vision_model) + self.resampler = None + is_pixtral = self.config.image_encoder_arch == "pixtral" + + if adapter.config.clip_layer == "image_embeds": + if isinstance(vision_model, SiglipVisionModel): + self.token_size = vision_model.config.hidden_size + else: + self.token_size = vision_model.config.projection_dim + else: + self.token_size = vision_model.config.hidden_size + + self.mid_size = self.token_size + + if self.config.conv_pooling and self.config.conv_pooling_stacks > 1: + self.mid_size = self.mid_size * self.config.conv_pooling_stacks + + # if pixtral, use cross attn dim for more sparse representation if only doing double transformers + if is_pixtral and self.config.flux_only_double: + if is_flux: + hidden_size = 3072 + else: + hidden_size = sd.unet.config['cross_attention_dim'] + self.mid_size = hidden_size + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + + elif is_flux: + transformer: FluxTransformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn") + + if not self.config.flux_only_double: + # single transformer blocks do not have cross attn, but we will do them anyway + for i, module in transformer.single_transformer_blocks.named_children(): + attn_processor_keys.append(f"single_transformer_blocks.{i}.attn") + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + current_idx = 0 + + for name in attn_processor_keys: + if is_flux: + cross_attention_dim = None + else: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = 
list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer") or name.startswith("single_transformer"): + if is_flux: + hidden_size = 3072 + else: + hidden_size = sd.unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None and not is_flux: + attn_procs[name] = AttnProcessor2_0() + else: + layer_name = name.split(".processor")[0] + if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux: + # is quantized + + to_k_adapter = torch.randn(hidden_size, hidden_size) * 0.01 + to_v_adapter = torch.randn(hidden_size, hidden_size) * 0.01 + to_k_adapter = to_k_adapter.to(self.sd_ref().torch_dtype) + to_v_adapter = to_v_adapter.to(self.sd_ref().torch_dtype) + else: + to_k_adapter = unet_sd[layer_name + ".to_k.weight"] + to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + + # add zero padding to the adapter + if to_k_adapter.shape[1] < self.mid_size: + to_k_adapter = torch.cat([ + to_k_adapter, + torch.randn(to_k_adapter.shape[0], self.mid_size - to_k_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + to_v_adapter = torch.cat([ + to_v_adapter, + torch.randn(to_v_adapter.shape[0], self.mid_size - to_v_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + elif to_k_adapter.shape[1] > self.mid_size: + to_k_adapter = to_k_adapter[:, :self.mid_size] + to_v_adapter = to_v_adapter[:, :self.mid_size] + # if is_pixart: + # to_k_bias = to_k_bias[:self.mid_size] + # to_v_bias = to_v_bias[:self.mid_size] + else: + to_k_adapter = to_k_adapter + to_v_adapter = to_v_adapter + # if is_pixart: + # to_k_bias = to_k_bias + # to_v_bias = to_v_bias + + weights = { + "to_k_adapter.weight": to_k_adapter * 0.01, + "to_v_adapter.weight": to_v_adapter * 0.01, + } + # if is_pixart: + # weights["to_k_adapter.bias"] = to_k_bias + # weights["to_v_adapter.bias"] = to_v_bias\ + + if is_flux: + attn_procs[name] = CustomFluxVDAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + adapter=self, + adapter_hidden_size=self.mid_size, + has_bias=False, + block_idx=current_idx + ) + else: + attn_procs[name] = VisionDirectAdapterAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + adapter=self, + adapter_hidden_size=self.mid_size, + has_bias=False, + ) + current_idx += 1 + attn_procs[name].load_state_dict(weights) + + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList([ + transformer.transformer_blocks[i].attn1.processor for i in range(len(transformer.transformer_blocks)) + ] + [ + transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks)) + ]) + elif self.sd_ref().is_flux: + # we have to set them ourselves + transformer: FluxTransformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"] + + 
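+            # (double/joint blocks are patched above; the single-stream blocks are only
+            #  patched below when flux_only_double is disabled)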
if not self.config.flux_only_double: + # do single blocks too even though they dont have cross attn + for i, module in transformer.single_transformer_blocks.named_children(): + module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"] + + if not self.config.flux_only_double: + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn.processor for i in + range(len(transformer.transformer_blocks)) + ] + [ + transformer.single_transformer_blocks[i].attn.processor for i in + range(len(transformer.single_transformer_blocks)) + ] + ) + else: + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn.processor for i in + range(len(transformer.transformer_blocks)) + ] + ) + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + num_modules = len(self.adapter_modules) + if self.config.train_scaler: + self.block_scaler = torch.nn.Parameter(torch.tensor([0.0] * num_modules).to( + dtype=torch.float32, + device=self.sd_ref().device_torch + )) + self.block_scaler.data = self.block_scaler.data.to(torch.float32) + self.block_scaler.requires_grad = True + else: + self.block_scaler = None + + self.pool = None + + if self.config.num_tokens is not None: + # image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict() + # max_seq_len = CLIP tokens + CLS token + # max_seq_len = 257 + # if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: + # # clip + # max_seq_len = int( + # image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + # self.resampler = MLPR( + # in_dim=self.token_size, + # in_channels=max_seq_len, + # out_dim=self.mid_size, + # out_channels=self.config.num_tokens, + # ) + vision_config = self.adapter_ref().vision_encoder.config + # sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1) + # siglip doesnt add 1 + sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2) + self.pool = nn.Sequential( + nn.Conv1d(sequence_length, self.config.num_tokens, 1, bias=False), + Norm(), + ) + + elif self.config.image_encoder_arch == "pixtral": + self.resampler = VisionLanguageAdapter( + in_dim=self.token_size, + out_dim=self.mid_size, + ) + + self.sparse_autoencoder = None + if self.config.conv_pooling: + vision_config = self.adapter_ref().vision_encoder.config + # sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1) + # siglip doesnt add 1 + sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2) + self.pool = nn.Sequential( + nn.Conv1d(sequence_length, self.config.conv_pooling_stacks, 1, bias=False), + Norm(), + ) + if self.config.sparse_autoencoder_dim is not None: + hidden_dim = self.token_size * 2 + if hidden_dim > self.config.sparse_autoencoder_dim: + hidden_dim = self.config.sparse_autoencoder_dim + self.sparse_autoencoder = SparseAutoencoder( + input_dim=self.token_size, + hidden_dim=hidden_dim, + output_dim=self.config.sparse_autoencoder_dim + ) + + if self.config.clip_layer == "image_embeds": + self.proj = nn.Linear(self.token_size, self.token_size) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + if self.config.train_scaler: + # only return the block scaler + if destination is None: + destination = OrderedDict() + destination[prefix + 'block_scaler'] = self.block_scaler + return destination + return super().state_dict(destination, 
prefix, keep_vars) + + # make a getter to see if is active + @property + def is_active(self): + return self.adapter_ref().is_active + + def forward(self, input): + # block scaler keeps moving dtypes. make sure it is float32 here + # todo remove this when we have a real solution + + if self.block_scaler is not None and self.block_scaler.dtype != torch.float32: + self.block_scaler.data = self.block_scaler.data.to(torch.float32) + # if doing image_embeds, normalize here + if self.config.clip_layer == "image_embeds": + input = norm_layer(input) + input = self.proj(input) + if self.resampler is not None: + input = self.resampler(input) + if self.pool is not None: + input = self.pool(input) + if self.config.conv_pooling_stacks > 1: + input = torch.cat(torch.chunk(input, self.config.conv_pooling_stacks, dim=1), dim=2) + if self.sparse_autoencoder is not None: + input = self.sparse_autoencoder(input) + return input + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + if self.block_scaler is not None: + if self.block_scaler.dtype != torch.float32: + self.block_scaler.data = self.block_scaler.data.to(torch.float32) + return self + + def post_weight_update(self): + # force block scaler to be mean of 1 + pass diff --git a/toolkit/models/zipper_resampler.py b/toolkit/models/zipper_resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..35f018b09bd49e802a9a26c225890412706bb1c8 --- /dev/null +++ b/toolkit/models/zipper_resampler.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn + + +class ContextualAlphaMask(nn.Module): + def __init__( + self, + dim: int = 768, + ): + super(ContextualAlphaMask, self).__init__() + self.dim = dim + + half_dim = dim // 2 + quarter_dim = dim // 4 + + self.fc1 = nn.Linear(self.dim, self.dim) + self.fc2 = nn.Linear(self.dim, half_dim) + self.norm1 = nn.LayerNorm(half_dim) + self.fc3 = nn.Linear(half_dim, half_dim) + self.fc4 = nn.Linear(half_dim, quarter_dim) + self.norm2 = nn.LayerNorm(quarter_dim) + self.fc5 = nn.Linear(quarter_dim, quarter_dim) + self.fc6 = nn.Linear(quarter_dim, 1) + # set fc6 weights to near zero + self.fc6.weight.data.normal_(mean=0.0, std=0.0001) + self.act_fn = nn.GELU() + + def forward(self, x): + # x = (batch_size, 77, 768) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x = self.norm1(x) + x = self.act_fn(x) + x = self.fc3(x) + x = self.act_fn(x) + x = self.fc4(x) + x = self.norm2(x) + x = self.act_fn(x) + x = self.fc5(x) + x = self.act_fn(x) + x = self.fc6(x) + x = torch.sigmoid(x) + return x + + +class ZipperModule(nn.Module): + def __init__( + self, + in_size, + in_tokens, + out_size, + out_tokens, + hidden_size, + hidden_tokens, + use_residual=False, + ): + super().__init__() + self.in_size = in_size + self.in_tokens = in_tokens + self.out_size = out_size + self.out_tokens = out_tokens + self.hidden_size = hidden_size + self.hidden_tokens = hidden_tokens + self.use_residual = use_residual + + self.act_fn = nn.GELU() + self.layernorm = nn.LayerNorm(self.in_size) + + self.conv1 = nn.Conv1d(self.in_tokens, self.hidden_tokens, 1) + # act + self.fc1 = nn.Linear(self.in_size, self.hidden_size) + # act + self.conv2 = nn.Conv1d(self.hidden_tokens, self.out_tokens, 1) + # act + self.fc2 = nn.Linear(self.hidden_size, self.out_size) + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.conv1(x) + x = self.act_fn(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.conv2(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class 
ZipperResampler(nn.Module): + def __init__( + self, + in_size, + in_tokens, + out_size, + out_tokens, + hidden_size, + hidden_tokens, + num_blocks=1, + is_conv_input=False, + ): + super().__init__() + self.is_conv_input = is_conv_input + + module_list = [] + for i in range(num_blocks): + + this_in_size = in_size + this_in_tokens = in_tokens + this_out_size = out_size + this_out_tokens = out_tokens + this_hidden_size = hidden_size + this_hidden_tokens = hidden_tokens + use_residual = False + + # maintain middle sizes as hidden_size + if i == 0: # first block + this_in_size = in_size + this_in_tokens = in_tokens + if num_blocks == 1: + this_out_size = out_size + this_out_tokens = out_tokens + else: + this_out_size = hidden_size + this_out_tokens = hidden_tokens + elif i == num_blocks - 1: # last block + this_out_size = out_size + this_out_tokens = out_tokens + if num_blocks == 1: + this_in_size = in_size + this_in_tokens = in_tokens + else: + this_in_size = hidden_size + this_in_tokens = hidden_tokens + else: # middle blocks + this_out_size = hidden_size + this_out_tokens = hidden_tokens + this_in_size = hidden_size + this_in_tokens = hidden_tokens + use_residual = True + + module_list.append(ZipperModule( + in_size=this_in_size, + in_tokens=this_in_tokens, + out_size=this_out_size, + out_tokens=this_out_tokens, + hidden_size=this_hidden_size, + hidden_tokens=this_hidden_tokens, + use_residual=use_residual + )) + + self.blocks = nn.ModuleList(module_list) + + self.ctx_alpha = ContextualAlphaMask( + dim=out_size, + ) + + def forward(self, x): + if self.is_conv_input: + # flatten + x = x.view(x.size(0), x.size(1), -1) + # rearrange to (batch, tokens, size) + x = x.permute(0, 2, 1) + + for block in self.blocks: + x = block(x) + alpha = self.ctx_alpha(x) + return x * alpha diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..37f7987e2868e545627206de8e2e2654884c12a4 --- /dev/null +++ b/toolkit/network_mixins.py @@ -0,0 +1,727 @@ +import json +import os +from collections import OrderedDict +from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal + +import torch +from optimum.quanto import QTensor +from torch import nn +import weakref + +from tqdm import tqdm + +from toolkit.config_modules import NetworkConfig +from toolkit.lorm import extract_conv, extract_linear, count_parameters +from toolkit.metadata import add_model_hash_to_meta +from toolkit.paths import KEYMAPS_ROOT +from toolkit.saving import get_lora_keymap_from_model_keymap +from optimum.quanto import QBytesTensor + +if TYPE_CHECKING: + from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule + from toolkit.lora_special import LoRASpecialNetwork, LoRAModule + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.models.DoRA import DoRAModule + +Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork'] +Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule'] + +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear', + 'QLinear' + # 'GroupNorm', +] +CONV_MODULES = [ + 'Conv2d', + 'LoRACompatibleConv' +] + +ExtractMode = Union[ + 'existing' + 'fixed', + 'threshold', + 'ratio', + 'quantile', + 'percentage' +] + + +def broadcast_and_multiply(tensor, multiplier): + # Determine the number of dimensions required + num_extra_dims = tensor.dim() - multiplier.dim() + + # Unsqueezing the tensor to match the dimensionality + for _ in range(num_extra_dims): + multiplier = 
multiplier.unsqueeze(-1) + + try: + # Multiplying the broadcasted tensor with the output tensor + result = tensor * multiplier + except RuntimeError as e: + print(e) + print(tensor.size()) + print(multiplier.size()) + raise e + + return result + + +def add_bias(tensor, bias): + if bias is None: + return tensor + # add batch dim + bias = bias.unsqueeze(0) + bias = torch.cat([bias] * tensor.size(0), dim=0) + # Determine the number of dimensions required + num_extra_dims = tensor.dim() - bias.dim() + + # Unsqueezing the tensor to match the dimensionality + for _ in range(num_extra_dims): + bias = bias.unsqueeze(-1) + + # we may need to swap -1 for -2 + if bias.size(1) != tensor.size(1): + if len(bias.size()) == 3: + bias = bias.permute(0, 2, 1) + elif len(bias.size()) == 4: + bias = bias.permute(0, 3, 1, 2) + + # Multiplying the broadcasted tensor with the output tensor + try: + result = tensor + bias + except RuntimeError as e: + print(e) + print(tensor.size()) + print(bias.size()) + raise e + + return result + + +class ExtractableModuleMixin: + def extract_weight( + self: Module, + extract_mode: ExtractMode = "existing", + extract_mode_param: Union[int, float] = None, + ): + device = self.lora_down.weight.device + weight_to_extract = self.org_module[0].weight + if extract_mode == "existing": + extract_mode = 'fixed' + extract_mode_param = self.lora_dim + + if isinstance(weight_to_extract, QBytesTensor): + weight_to_extract = weight_to_extract.dequantize() + + weight_to_extract = weight_to_extract.clone().detach().float() + + if self.org_module[0].__class__.__name__ in CONV_MODULES: + # do conv extraction + down_weight, up_weight, new_dim, diff = extract_conv( + weight=weight_to_extract, + mode=extract_mode, + mode_param=extract_mode_param, + device=device + ) + + elif self.org_module[0].__class__.__name__ in LINEAR_MODULES: + # do linear extraction + down_weight, up_weight, new_dim, diff = extract_linear( + weight=weight_to_extract, + mode=extract_mode, + mode_param=extract_mode_param, + device=device, + ) + else: + raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}") + + self.lora_dim = new_dim + + # inject weights into the param + self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach() + self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach() + + # copy bias if we have one and are using them + if self.org_module[0].bias is not None and self.lora_up.bias is not None: + self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach() + + # set up alphas + self.alpha = (self.alpha * 0) + down_weight.shape[0] + self.scale = self.alpha / self.lora_dim + + # assign them + + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + # scaler is a parameter update the value with 1.0 + self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype) + + +class ToolkitModuleMixin: + def __init__( + self: Module, + *args, + network: Network, + **kwargs + ): + self.network_ref: weakref.ref = weakref.ref(network) + self.is_checkpointing = False + self._multiplier: Union[float, list, torch.Tensor] = None + + def _call_forward(self: Module, x): + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return 0.0 # added to original forward + + if hasattr(self, 'lora_mid') and self.lora_mid is not None: + lx = self.lora_mid(self.lora_down(x)) + else: + try: + lx = self.lora_down(x) + except RuntimeError as e: + 
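+                # note: the exception is only printed here, so if lora_down fails,
+                # lx stays undefined and the dropout call below will raise NameError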
print(f"Error in {self.__class__.__name__} lora_down") + print(e) + + if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity): + lx = self.dropout(lx) + # normal dropout + elif self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.rank_dropout > 0 and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + scale = scale * self.scalar + + return lx * scale + + def lorm_forward(self: Network, x, *args, **kwargs): + network: Network = self.network_ref() + if not network.is_active: + return self.org_forward(x, *args, **kwargs) + + orig_dtype = x.dtype + + if x.dtype != self.lora_down.weight.dtype: + x = x.to(self.lora_down.weight.dtype) + + if network.lorm_train_mode == 'local': + # we are going to predict input with both and do a loss on them + inputs = x.detach() + with torch.no_grad(): + # get the local prediction + target_pred = self.org_forward(inputs, *args, **kwargs).detach() + with torch.set_grad_enabled(True): + # make a prediction with the lorm + lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True))) + + local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float()) + # backpropr + local_loss.backward() + + network.module_losses.append(local_loss.detach()) + # return the original as we dont want our trainer to affect ones down the line + return target_pred + + else: + x = self.lora_up(self.lora_down(x)) + if x.dtype != orig_dtype: + x = x.to(orig_dtype) + + def forward(self: Module, x, *args, **kwargs): + skip = False + network: Network = self.network_ref() + if network.is_lorm: + # we are doing lorm + return self.lorm_forward(x, *args, **kwargs) + + # skip if not active + if not network.is_active: + skip = True + + # skip if is merged in + if network.is_merged_in: + skip = True + + # skip if multiplier is 0 + if network._multiplier == 0: + skip = True + + if skip: + # network is not active, avoid doing anything + return self.org_forward(x, *args, **kwargs) + + # if self.__class__.__name__ == "DoRAModule": + # # return dora forward + # return self.dora_forward(x, *args, **kwargs) + + org_forwarded = self.org_forward(x, *args, **kwargs) + + if isinstance(x, QTensor): + x = x.dequantize() + # always cast to float32 + lora_input = x.to(self.lora_down.weight.dtype) + lora_output = self._call_forward(lora_input) + multiplier = self.network_ref().torch_multiplier + + lora_output_batch_size = lora_output.size(0) + multiplier_batch_size = multiplier.size(0) + if lora_output_batch_size != multiplier_batch_size: + num_interleaves = lora_output_batch_size // multiplier_batch_size + # todo check if this is correct, do we just concat when doing cfg? 
+ multiplier = multiplier.repeat_interleave(num_interleaves) + + scaled_lora_output = broadcast_and_multiply(lora_output, multiplier) + scaled_lora_output = scaled_lora_output.to(org_forwarded.dtype) + + if self.__class__.__name__ == "DoRAModule": + # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417 + # x = dropout(x) + # todo this wont match the dropout applied to the lora + if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity): + lx = self.dropout(x) + # normal dropout + elif self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(x, p=self.dropout) + else: + lx = x + lora_weight = self.lora_up.weight @ self.lora_down.weight + # scale it here + # todo handle our batch split scalers for slider training. For now take the mean of them + scale = multiplier.mean() + scaled_lora_weight = lora_weight * scale + scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight).to(org_forwarded.dtype) + + try: + x = org_forwarded + scaled_lora_output + except RuntimeError as e: + print(e) + print(org_forwarded.size()) + print(scaled_lora_output.size()) + raise e + return x + + def enable_gradient_checkpointing(self: Module): + self.is_checkpointing = True + + def disable_gradient_checkpointing(self: Module): + self.is_checkpointing = False + + @torch.no_grad() + def merge_out(self: Module, merge_out_weight=1.0): + # make sure it is positive + merge_out_weight = abs(merge_out_weight) + # merging out is just merging in the negative of the weight + self.merge_in(merge_weight=-merge_out_weight) + + @torch.no_grad() + def merge_in(self: Module, merge_weight=1.0): + if not self.can_merge_in: + return + # get up/down weight + up_weight = self.lora_up.weight.clone().float() + down_weight = self.lora_down.weight.clone().float() + + # extract weight from org_module + org_sd = self.org_module[0].state_dict() + # todo find a way to merge in weights when doing quantized model + if 'weight._data' in org_sd: + # quantized weight + return + + weight_key = "weight" + if 'weight._data' in org_sd: + # quantized weight + weight_key = "weight._data" + + orig_dtype = org_sd[weight_key].dtype + weight = org_sd[weight_key].float() + + multiplier = merge_weight + scale = self.scale + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + scale = scale * self.scalar + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + multiplier * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + multiplier * conved * scale + + # set weight to org_module + org_sd[weight_key] = weight.to(orig_dtype) + self.org_module[0].load_state_dict(org_sd) + + def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None): + # LoRM (Low Rank Middle) is a method reduce the number of parameters in a module while keeping the inputs and + # outputs the same. 
It is basically a LoRA but with the original module removed + + # if a state dict is passed, use those weights instead of extracting + # todo load from state dict + network: Network = self.network_ref() + lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name) + + extract_mode = lorm_config.extract_mode + extract_mode_param = lorm_config.extract_mode_param + parameter_threshold = lorm_config.parameter_threshold + self.extract_weight( + extract_mode=extract_mode, + extract_mode_param=extract_mode_param + ) + + +class ToolkitNetworkMixin: + def __init__( + self: Network, + *args, + train_text_encoder: Optional[bool] = True, + train_unet: Optional[bool] = True, + is_sdxl=False, + is_v2=False, + is_ssd=False, + is_vega=False, + network_config: Optional[NetworkConfig] = None, + is_lorm=False, + **kwargs + ): + self.train_text_encoder = train_text_encoder + self.train_unet = train_unet + self.is_checkpointing = False + self._multiplier: float = 1.0 + self.is_active: bool = False + self.is_sdxl = is_sdxl + self.is_ssd = is_ssd + self.is_vega = is_vega + self.is_v2 = is_v2 + self.is_v1 = not is_v2 and not is_sdxl and not is_ssd and not is_vega + self.is_merged_in = False + self.is_lorm = is_lorm + self.network_config: NetworkConfig = network_config + self.module_losses: List[torch.Tensor] = [] + self.lorm_train_mode: Literal['local', None] = None + self.can_merge_in = not is_lorm + + def get_keymap(self: Network, force_weight_mapping=False): + use_weight_mapping = False + + if self.is_ssd: + keymap_tail = 'ssd' + use_weight_mapping = True + elif self.is_vega: + keymap_tail = 'vega' + use_weight_mapping = True + elif self.is_sdxl: + keymap_tail = 'sdxl' + elif self.is_v2: + keymap_tail = 'sd2' + else: + keymap_tail = 'sd1' + # todo double check this + # use_weight_mapping = True + + if force_weight_mapping: + use_weight_mapping = True + + # load keymap + keymap_name = f"stable_diffusion_locon_{keymap_tail}.json" + if use_weight_mapping: + keymap_name = f"stable_diffusion_{keymap_tail}.json" + + keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name) + + keymap = None + # check if file exists + if os.path.exists(keymap_path): + with open(keymap_path, 'r') as f: + keymap = json.load(f)['ldm_diffusers_keymap'] + + if use_weight_mapping and keymap is not None: + # get keymap from weights + keymap = get_lora_keymap_from_model_keymap(keymap) + + # upgrade keymaps for DoRA + if self.network_type.lower() == 'dora': + if keymap is not None: + new_keymap = {} + for ldm_key, diffusers_key in keymap.items(): + ldm_key = ldm_key.replace('.alpha', '.magnitude') + # ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down') + # ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up') + + diffusers_key = diffusers_key.replace('.alpha', '.magnitude') + # diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down') + # diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up') + + new_keymap[ldm_key] = diffusers_key + + keymap = new_keymap + + return keymap + + def save_weights( + self: Network, + file, dtype=torch.float16, + metadata=None, + extra_state_dict: Optional[OrderedDict] = None + ): + keymap = self.get_keymap() + + save_keymap = {} + if keymap is not None: + for ldm_key, diffusers_key in keymap.items(): + # invert them + save_keymap[diffusers_key] = ldm_key + + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + save_dict = OrderedDict() + + for key in list(state_dict.keys()): + v = state_dict[key] + v 
= v.detach().clone().to("cpu").to(dtype) + save_key = save_keymap[key] if key in save_keymap else key + save_dict[save_key] = v + del state_dict[key] + + if extra_state_dict is not None: + # add extra items to state dict + for key in list(extra_state_dict.keys()): + v = extra_state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + save_dict[key] = v + + if self.peft_format: + # lora_down = lora_A + # lora_up = lora_B + # no alpha + + new_save_dict = {} + for key, value in save_dict.items(): + if key.endswith('.alpha'): + continue + new_key = key + new_key = new_key.replace('lora_down', 'lora_A') + new_key = new_key.replace('lora_up', 'lora_B') + # replace all $$ with . + new_key = new_key.replace('$$', '.') + new_save_dict[new_key] = value + + save_dict = new_save_dict + + if metadata is None: + metadata = OrderedDict() + metadata = add_model_hash_to_meta(state_dict, metadata) + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + save_file(save_dict, file, metadata) + else: + torch.save(save_dict, file) + + def load_weights(self: Network, file, force_weight_mapping=False): + # allows us to save and load to and from ldm weights + keymap = self.get_keymap(force_weight_mapping) + keymap = {} if keymap is None else keymap + + if isinstance(file, str): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + else: + # probably a state dict + weights_sd = file + + load_sd = OrderedDict() + for key, value in weights_sd.items(): + load_key = keymap[key] if key in keymap else key + # replace old double __ with single _ + if self.is_pixart: + load_key = load_key.replace('__', '_') + + if self.peft_format: + # lora_down = lora_A + # lora_up = lora_B + # no alpha + if load_key.endswith('.alpha'): + continue + load_key = load_key.replace('lora_A', 'lora_down') + load_key = load_key.replace('lora_B', 'lora_up') + # replace all . 
with $$ + load_key = load_key.replace('.', '$$') + load_key = load_key.replace('$$lora_down$$', '.lora_down.') + load_key = load_key.replace('$$lora_up$$', '.lora_up.') + + load_sd[load_key] = value + + # extract extra items from state dict + current_state_dict = self.state_dict() + extra_dict = OrderedDict() + to_delete = [] + for key in list(load_sd.keys()): + if key not in current_state_dict: + extra_dict[key] = load_sd[key] + to_delete.append(key) + for key in to_delete: + del load_sd[key] + + print(f"Missing keys: {to_delete}") + if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not ( + len(to_delete) == 1 and 'emb_params' in to_delete): + print(" Attempting to load with forced keymap") + return self.load_weights(file, force_weight_mapping=True) + + info = self.load_state_dict(load_sd, False) + if len(extra_dict.keys()) == 0: + extra_dict = None + return extra_dict + + @torch.no_grad() + def _update_torch_multiplier(self: Network): + # builds a tensor for fast usage in the forward pass of the network modules + # without having to set it in every single module every time it changes + multiplier = self._multiplier + # get first module + first_module = self.get_all_modules()[0] + device = first_module.lora_down.weight.device + dtype = first_module.lora_down.weight.dtype + with torch.no_grad(): + tensor_multiplier = None + if isinstance(multiplier, int) or isinstance(multiplier, float): + tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype) + elif isinstance(multiplier, list): + tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype) + elif isinstance(multiplier, torch.Tensor): + tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype) + + self.torch_multiplier = tensor_multiplier.clone().detach() + + @property + def multiplier(self) -> Union[float, List[float], List[List[float]]]: + return self._multiplier + + @multiplier.setter + def multiplier(self, value: Union[float, List[float], List[List[float]]]): + # it takes time to update all the multipliers, so we only do it if the value has changed + if self._multiplier == value: + return + # if we are setting a single value but have a list, keep the list if every item is the same as value + self._multiplier = value + self._update_torch_multiplier() + + # called when the context manager is entered + # ie: with network: + def __enter__(self: Network): + self.is_active = True + + def __exit__(self: Network, exc_type, exc_value, tb): + self.is_active = False + + def force_to(self: Network, device, dtype): + self.to(device, dtype) + loras = [] + if hasattr(self, 'unet_loras'): + loras += self.unet_loras + if hasattr(self, 'text_encoder_loras'): + loras += self.text_encoder_loras + for lora in loras: + lora.to(device, dtype) + + def get_all_modules(self: Network) -> List[Module]: + loras = [] + if hasattr(self, 'unet_loras'): + loras += self.unet_loras + if hasattr(self, 'text_encoder_loras'): + loras += self.text_encoder_loras + return loras + + def _update_checkpointing(self: Network): + for module in self.get_all_modules(): + if self.is_checkpointing: + module.enable_gradient_checkpointing() + else: + module.disable_gradient_checkpointing() + + def enable_gradient_checkpointing(self: Network): + # not supported + self.is_checkpointing = True + self._update_checkpointing() + + def disable_gradient_checkpointing(self: Network): + # not supported + self.is_checkpointing = False + self._update_checkpointing() + + def merge_in(self, merge_weight=1.0): + if self.network_type.lower() 
== 'dora': + return + self.is_merged_in = True + for module in self.get_all_modules(): + module.merge_in(merge_weight) + + def merge_out(self: Network, merge_weight=1.0): + if not self.is_merged_in: + return + self.is_merged_in = False + for module in self.get_all_modules(): + module.merge_out(merge_weight) + + def extract_weight( + self: Network, + extract_mode: ExtractMode = "existing", + extract_mode_param: Union[int, float] = None, + ): + if extract_mode_param is None: + raise ValueError("extract_mode_param must be set") + for module in tqdm(self.get_all_modules(), desc="Extracting weights"): + module.extract_weight( + extract_mode=extract_mode, + extract_mode_param=extract_mode_param + ) + + def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None): + for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"): + module.setup_lorm(state_dict=state_dict) + + def calculate_lorem_parameter_reduction(self): + params_reduced = 0 + for module in self.get_all_modules(): + num_orig_module_params = count_parameters(module.org_module[0]) + num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up) + params_reduced += (num_orig_module_params - num_lorem_params) + + return params_reduced diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a258ff3a4c37c576d2bf51787cff00af0bafa4 --- /dev/null +++ b/toolkit/optimizer.py @@ -0,0 +1,103 @@ +import torch +from transformers import Adafactor, AdamW + + +def get_optimizer( + params, + optimizer_type='adam', + learning_rate=1e-6, + optimizer_params=None +): + if optimizer_params is None: + optimizer_params = {} + lower_type = optimizer_type.lower() + if lower_type.startswith("dadaptation"): + # dadaptation optimizer does not use standard learning rate. 1 is the default value + import dadaptation + print("Using DAdaptAdam optimizer") + use_lr = learning_rate + if use_lr < 0.1: + # dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0 + use_lr = 1.0 + if lower_type.endswith('lion'): + optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params) + elif lower_type.endswith('adam'): + optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params) + elif lower_type == 'dadaptation': + # backwards compatibility + optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params) + # warn user that dadaptation is deprecated + print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.") + elif lower_type.startswith("prodigy8bit"): + from toolkit.optimizers.prodigy_8bit import Prodigy8bit + print("Using Prodigy optimizer") + use_lr = learning_rate + if use_lr < 0.1: + # dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0 + use_lr = 1.0 + + print(f"Using lr {use_lr}") + # let net be the neural network you want to train + # you can choose weight decay value based on your problem, 0 by default + optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-6, **optimizer_params) + elif lower_type.startswith("prodigy"): + from prodigyopt import Prodigy + + print("Using Prodigy optimizer") + use_lr = learning_rate + if use_lr < 0.1: + # dadaptation uses different lr that is values of 0.1 to 1.0. 
default to 1.0 + use_lr = 1.0 + + print(f"Using lr {use_lr}") + # let net be the neural network you want to train + # you can choose weight decay value based on your problem, 0 by default + optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params) + elif lower_type == "adam8": + from toolkit.optimizers.adam8bit import Adam8bit + + optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) + elif lower_type == "adamw8": + from toolkit.optimizers.adam8bit import Adam8bit + + optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, decouple=True, **optimizer_params) + elif lower_type.endswith("8bit"): + import bitsandbytes + + if lower_type == "adam8bit": + return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) + if lower_type == "ademamix8bit": + return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) + elif lower_type == "adamw8bit": + return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) + elif lower_type == "lion8bit": + return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params) + else: + raise ValueError(f'Unknown optimizer type {optimizer_type}') + elif lower_type == 'adam': + optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) + elif lower_type == 'adamw': + optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) + elif lower_type == 'lion': + try: + from lion_pytorch import Lion + return Lion(params, lr=learning_rate, **optimizer_params) + except ImportError: + raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch") + elif lower_type == 'adagrad': + optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params) + elif lower_type == 'adafactor': + from toolkit.optimizers.adafactor import Adafactor + if 'relative_step' not in optimizer_params: + optimizer_params['relative_step'] = False + if 'scale_parameter' not in optimizer_params: + optimizer_params['scale_parameter'] = False + if 'warmup_init' not in optimizer_params: + optimizer_params['warmup_init'] = False + optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) + elif lower_type == 'automagic': + from toolkit.optimizers.automagic import Automagic + optimizer = Automagic(params, lr=float(learning_rate), **optimizer_params) + else: + raise ValueError(f'Unknown optimizer type {optimizer_type}') + return optimizer diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..00cf06ee4ab7afbd201414b2ca2a302ffe539c9f --- /dev/null +++ b/toolkit/optimizers/adafactor.py @@ -0,0 +1,359 @@ +import math +from typing import List +import torch +from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation +from optimum.quanto import QBytesTensor +import random + + +class Adafactor(torch.optim.Optimizer): + """ + Adafactor implementation with stochastic rounding accumulation and stochastic rounding on apply. + Modified from transformers Adafactor implementation to support stochastic rounding accumulation and apply. 
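+    Stochastic rounding writes the float32 update back to a lower-precision parameter by rounding
+    up or down with probability proportional to the remainder, so small updates are preserved on
+    average instead of being truncated away.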
+ + AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: + https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + + Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that + this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and + `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*): + The external learning rate. + eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`): + Regularization constants for square gradient and parameter scale respectively + clip_threshold (`float`, *optional*, defaults to 1.0): + Threshold of root mean square of final gradient update + decay_rate (`float`, *optional*, defaults to -0.8): + Coefficient used to compute running averages of square + beta1 (`float`, *optional*): + Coefficient used for computing running averages of gradient + weight_decay (`float`, *optional*, defaults to 0.0): + Weight decay (L2 penalty) + scale_parameter (`bool`, *optional*, defaults to `True`): + If True, learning rate is scaled by root mean square + relative_step (`bool`, *optional*, defaults to `True`): + If True, time-dependent learning rate is computed instead of external learning rate + warmup_init (`bool`, *optional*, defaults to `False`): + Time-dependent learning rate computation depends on whether warm-up initialization is being used + + This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. + + Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): + + - Training without LR warmup or clip_threshold is not recommended. 
+ + - use scheduled LR warm-up to fixed LR + - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) + - Disable relative updates + - Use scale_parameter=False + - Additional optimizer operations like gradient clipping should not be used alongside Adafactor + + Example: + + ```python + Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) + ``` + + Others reported the following combination to work well: + + ```python + Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + ``` + + When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] + scheduler as following: + + ```python + from transformers.optimization import Adafactor, AdafactorSchedule + + optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + lr_scheduler = AdafactorSchedule(optimizer) + trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) + ``` + + Usage: + + ```python + # replace AdamW with Adafactor + optimizer = Adafactor( + model.parameters(), + lr=1e-3, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + relative_step=False, + scale_parameter=False, + warmup_init=False, + ) + ```""" + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + do_paramiter_swapping=False, + paramiter_swapping_factor=0.1, + ): + if lr is not None and relative_step: + raise ValueError( + "Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError( + "`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + + self.base_lrs: List[float] = [ + lr for group in self.param_groups + ] + + self.is_stochastic_rounding_accumulation = False + + # setup stochastic grad accum hooks + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and param.dtype != torch.float32: + self.is_stochastic_rounding_accumulation = True + param.register_post_accumulate_grad_hook( + stochastic_grad_accummulation + ) + + self.do_paramiter_swapping = do_paramiter_swapping + self.paramiter_swapping_factor = paramiter_swapping_factor + self._total_paramiter_size = 0 + # count total paramiters + for group in self.param_groups: + for param in group['params']: + self._total_paramiter_size += torch.numel(param) + # pretty print total paramiters with comma seperation + print(f"Total training paramiters: {self._total_paramiter_size:,}") + + # needs to be enabled to count paramiters + if self.do_paramiter_swapping: + self.enable_paramiter_swapping(self.paramiter_swapping_factor) + + + def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1): + self.do_paramiter_swapping = True + self.paramiter_swapping_factor = paramiter_swapping_factor + # call it an initial time + self.swap_paramiters() + + def swap_paramiters(self): + all_params = [] + # deactivate all paramiters + for group in self.param_groups: + for param in group['params']: + param.requires_grad_(False) + # remove any grad + 
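+                # (only the randomly re-enabled subset chosen below will accumulate
+                #  gradients and receive optimizer updates this cycle)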
param.grad = None + all_params.append(param) + # shuffle all paramiters + random.shuffle(all_params) + + # keep activating paramiters until we are going to go over the target paramiters + target_paramiters = int(self._total_paramiter_size * self.paramiter_swapping_factor) + total_paramiters = 0 + for param in all_params: + total_paramiters += torch.numel(param) + if total_paramiters >= target_paramiters: + break + else: + param.requires_grad_(True) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * \ + param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=- + 1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step_hook(self): + if not self.is_stochastic_rounding_accumulation: + return + # copy over stochastically rounded grads + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and hasattr(param, "_accum_grad"): + param.grad = param._accum_grad + del param._accum_grad + + # adafactor manages its own lr + def get_learning_rates(self): + lrs = [ + self._get_lr(group, self.state[group["params"][0]]) + for group in self.param_groups + if group["params"][0].grad is not None + ] + if len(lrs) == 0: + lrs = self.base_lrs # if called before stepping + return lrs + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self.step_hook() + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None or not p.requires_grad: + continue + + grad = p.grad + if grad.dtype != torch.float32: + grad = grad.to(torch.float32) + if grad.is_sparse: + raise RuntimeError( + "Adafactor does not support sparse gradients.") + + # if p has atts _scale then it is quantized. 
We need to divide the grad by the scale + # if hasattr(p, "_scale"): + # grad = grad / p._scale + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options( + group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to( + grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to( + grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + + if isinstance(p_data_fp32, QBytesTensor): + p_data_fp32 = p_data_fp32.dequantize() + if p.dtype != torch.float32: + p_data_fp32 = p_data_fp32.clone().float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + eps = group["eps"] + if isinstance(eps, tuple) or isinstance(eps, list): + eps = eps[0] + update = (grad**2) + eps + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad( + exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_( + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_( + update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype != torch.float32: + # apply stochastic rounding + copy_stochastic(p, p_data_fp32) + + return loss diff --git a/toolkit/optimizers/adam8bit.py b/toolkit/optimizers/adam8bit.py new file mode 100644 index 0000000000000000000000000000000000000000..b5fc976bf456c6f93e26776a614e150450e1875e --- /dev/null +++ b/toolkit/optimizers/adam8bit.py @@ -0,0 +1,162 @@ +import math +import torch +from torch.optim import Optimizer +from toolkit.optimizers.optimizer_utils import copy_stochastic, Auto8bitTensor, stochastic_grad_accummulation + +class Adam8bit(Optimizer): + """ + Implements Adam optimizer with 8-bit state storage and stochastic rounding. 
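+    Moment estimates (`exp_avg`, `exp_avg_sq`) are kept as `Auto8bitTensor` (int8 data plus a
+    scale) and parameter writes go through `copy_stochastic`, so bf16/fp16 parameters still
+    accumulate small updates.
+
+    Example (a minimal sketch; `model` is assumed to be any `torch.nn.Module`):
+
+    ```python
+    Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)
+    ```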
+ + Arguments: + params (iterable): Iterable of parameters to optimize or dicts defining parameter groups + lr (float): Learning rate (default: 1e-3) + betas (tuple): Coefficients for computing running averages of gradient and its square (default: (0.9, 0.999)) + eps (float): Term added to denominator to improve numerical stability (default: 1e-8) + weight_decay (float): Weight decay coefficient (default: 0) + decouple (bool): Use AdamW style decoupled weight decay (default: True) + """ + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, decouple=True): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + decouple=decouple) + super(Adam8bit, self).__init__(params, defaults) + + self.is_stochastic_rounding_accumulation = False + + # Setup stochastic grad accumulation hooks + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and param.dtype != torch.float32: + self.is_stochastic_rounding_accumulation = True + param.register_post_accumulate_grad_hook( + stochastic_grad_accummulation + ) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step_hook(self): + if not self.is_stochastic_rounding_accumulation: + return + # Copy over stochastically rounded grads + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and hasattr(param, "_accum_grad"): + param.grad = param._accum_grad + del param._accum_grad + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. 
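+
+        Note: `step_hook()` runs first so that gradients accumulated with stochastic
+        rounding (stored in `param._accum_grad` by the post-accumulate hook) are copied
+        back onto `param.grad` before the update.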
+ """ + # Call pre step + self.step_hook() + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + lr = group['lr'] + decay = group['weight_decay'] + decouple = group['decouple'] + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad.data.to(torch.float32) + p_fp32 = p.clone().to(torch.float32) + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p_fp32.data, alpha=decay) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = Auto8bitTensor( + torch.zeros_like(p_fp32.data).detach()) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = Auto8bitTensor( + torch.zeros_like(p_fp32.data).detach()) + + exp_avg = state['exp_avg'].to(torch.float32) + exp_avg_sq = state['exp_avg_sq'].to(torch.float32) + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Adam EMA updates + exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple: + p_fp32.data.mul_(1 - lr * decay) + + # Bias correction + step_size = lr / bias_correction1 + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + # Take step + p_fp32.data.addcdiv_(exp_avg, denom, value=-step_size) + + # Update state with stochastic rounding + state['exp_avg'] = Auto8bitTensor(exp_avg) + state['exp_avg_sq'] = Auto8bitTensor(exp_avg_sq) + + # Apply stochastic rounding to parameters + copy_stochastic(p.data, p_fp32.data) + + return loss + + def state_dict(self): + """Returns the state of the optimizer as a dict.""" + state_dict = super().state_dict() + + # Convert Auto8bitTensor objects to regular state dicts + for param_id, param_state in state_dict['state'].items(): + for key, value in param_state.items(): + if isinstance(value, Auto8bitTensor): + param_state[key] = { + '_type': 'Auto8bitTensor', + 'state': value.state_dict() + } + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the optimizer state.""" + # First, load the basic state + super().load_state_dict(state_dict) + + # Then convert any Auto8bitTensor states back to objects + for param_id, param_state in self.state.items(): + for key, value in param_state.items(): + if isinstance(value, dict) and value.get('_type') == 'Auto8bitTensor': + param_state[key] = Auto8bitTensor(value['state']) + diff --git a/toolkit/optimizers/automagic.py b/toolkit/optimizers/automagic.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7355f168aa03d77a01124c5cb42f1ea625a61a --- /dev/null +++ b/toolkit/optimizers/automagic.py @@ -0,0 +1,335 @@ +from collections import OrderedDict +import math +from typing import List +import torch +from toolkit.optimizers.optimizer_utils import Auto8bitTensor, copy_stochastic, stochastic_grad_accummulation +from optimum.quanto import QBytesTensor +import random + + +class Automagic(torch.optim.Optimizer): + def __init__( + self, + params, + lr=None, + min_lr=1e-7, + max_lr=1e-3, + lr_pump_scale=1.1, + lr_dump_scale=0.85, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + weight_decay=0.0, + do_paramiter_swapping=False, + paramiter_swapping_factor=0.1, + ): + self.lr = lr + self.min_lr = min_lr + self.max_lr = max_lr + 
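+        # per-parameter learning rates grow by lr_pump_scale when an update keeps its sign
+        # and shrink by lr_dump_scale when it flips (see the lr-mask logic in step())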
self.lr_pump_scale = lr_pump_scale + self.lr_dump_scale = lr_dump_scale + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "weight_decay": weight_decay, + } + super().__init__(params, defaults) + + self.base_lrs: List[float] = [ + lr for group in self.param_groups + ] + + self.is_stochastic_rounding_accumulation = False + + # setup stochastic grad accum hooks + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and param.dtype != torch.float32: + self.is_stochastic_rounding_accumulation = True + param.register_post_accumulate_grad_hook( + stochastic_grad_accummulation + ) + + self.do_paramiter_swapping = do_paramiter_swapping + self.paramiter_swapping_factor = paramiter_swapping_factor + self._total_paramiter_size = 0 + # count total paramiters + for group in self.param_groups: + for param in group['params']: + self._total_paramiter_size += torch.numel(param) + # pretty print total paramiters with comma seperation + print(f"Total training paramiters: {self._total_paramiter_size:,}") + + # needs to be enabled to count paramiters + if self.do_paramiter_swapping: + self.enable_paramiter_swapping(self.paramiter_swapping_factor) + + def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1): + self.do_paramiter_swapping = True + self.paramiter_swapping_factor = paramiter_swapping_factor + # call it an initial time + self.swap_paramiters() + + def swap_paramiters(self): + all_params = [] + # deactivate all paramiters + for group in self.param_groups: + for param in group['params']: + param.requires_grad_(False) + # remove any grad + param.grad = None + all_params.append(param) + # shuffle all paramiters + random.shuffle(all_params) + + # keep activating paramiters until we are going to go over the target paramiters + target_paramiters = int( + self._total_paramiter_size * self.paramiter_swapping_factor) + total_paramiters = 0 + for param in all_params: + total_paramiters += torch.numel(param) + if total_paramiters >= target_paramiters: + break + else: + param.requires_grad_(True) + + @staticmethod + def _get_lr(param_group, param_state): + if 'avg_lr' in param_state: + lr = param_state["avg_lr"] + else: + lr = 0.0 + return lr + + def _get_group_lr(self, group): + group_lrs = [] + for p in group["params"]: + group_lrs.append(self._get_lr(group, self.state[p])) + # return avg + if len(group_lrs) == 0: + return self.lr + return sum(group_lrs) / len(group_lrs) + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=- + 1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step_hook(self): + if not self.is_stochastic_rounding_accumulation: + return + # copy over stochastically rounded grads + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and hasattr(param, "_accum_grad"): + param.grad = param._accum_grad + del param._accum_grad + + # adafactor manages its own lr + def get_learning_rates(self): + + lrs = [ + self._get_group_lr(group) + for group in self.param_groups + ] + if len(lrs) == 0: + lrs = self.base_lrs # if called before stepping 
+ return lrs + + def get_avg_learning_rate(self): + lrs = self.get_learning_rates() + return sum(lrs) / len(lrs) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self.step_hook() + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None or not p.requires_grad: + continue + + grad = p.grad + if grad.dtype != torch.float32: + grad = grad.to(torch.float32) + if grad.is_sparse: + raise RuntimeError( + "Automagic does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored = len(grad_shape) >= 2 + # State Initialization + if len(state) == 0: + self.initialize_state(p) + else: + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to( + grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to( + grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + + if isinstance(p_data_fp32, QBytesTensor): + p_data_fp32 = p_data_fp32.dequantize() + if p.dtype != torch.float32: + p_data_fp32 = p_data_fp32.clone().float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + # lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + eps = group["eps"] + if isinstance(eps, tuple) or isinstance(eps, list): + eps = eps[0] + update = (grad**2) + eps + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad( + exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_( + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + + # calculate new lr mask. if the updated param is going in same direction, increase lr, else decrease + # update the lr mask. self.lr_momentum is < 1.0. If a paramiter is positive and increasing (or negative and decreasing), increase lr, + # for that single paramiter. If a paramiter is negative and increasing or positive and decreasing, decrease lr for that single paramiter. + # to decrease lr, multiple by self.lr_momentum, to increase lr, divide by self.lr_momentum. 
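+                # (the factors referred to above are self.lr_pump_scale and self.lr_dump_scale
+                # in the current implementation; there is no self.lr_momentum attribute)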
+ + # not doing it this way anymore + # update.mul_(lr) + + # Get signs of current last update and updates + last_polarity = state['last_polarity'] + current_polarity = (update > 0).to(torch.bool) + sign_agreement = torch.where( + last_polarity == current_polarity, 1, -1) + state['last_polarity'] = current_polarity + + lr_mask = state['lr_mask'].to(torch.float32) + + # Update learning rate mask based on sign agreement + new_lr = torch.where( + sign_agreement > 0, + lr_mask * self.lr_pump_scale, # Increase lr + lr_mask * self.lr_dump_scale # Decrease lr + ) + + # Clip learning rates to bounds + new_lr = torch.clamp( + new_lr, + min=self.min_lr, + max=self.max_lr + ) + + # Apply the learning rate mask to the update + update.mul_(new_lr) + + state['lr_mask'] = Auto8bitTensor(new_lr) + state['avg_lr'] = torch.mean(new_lr) + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=(-group["weight_decay"] * new_lr)) + + p_data_fp32.add_(-update) + + if p.dtype != torch.float32: + # apply stochastic rounding + copy_stochastic(p, p_data_fp32) + + return loss + + def initialize_state(self, p): + state = self.state[p] + state["step"] = 0 + + # store the lr mask + if 'lr_mask' not in state: + state['lr_mask'] = Auto8bitTensor(torch.ones( + p.shape).to(p.device, dtype=torch.float32) * self.lr + ) + state['avg_lr'] = torch.mean( + state['lr_mask'].to(torch.float32)) + if 'last_polarity' not in state: + state['last_polarity'] = torch.zeros( + p.shape, dtype=torch.bool, device=p.device) + + factored = len(p.shape) >= 2 + if factored: + state["exp_avg_sq_row"] = torch.zeros( + p.shape[:-1]).to(p) + state["exp_avg_sq_col"] = torch.zeros( + p.shape[:-2] + p.shape[-1:]).to(p) + else: + state["exp_avg_sq"] = torch.zeros_like(p) + + state["RMS"] = 0 + + # override the state_dict to save the lr_mask + def state_dict(self, *args, **kwargs): + orig_state_dict = super().state_dict(*args, **kwargs) + # convert the state to quantized tensor to scale and quantized + new_sace_state = {} + for p, state in orig_state_dict['state'].items(): + save_state = {k: v for k, v in state.items() if k != 'lr_mask'} + save_state['lr_mask'] = state['lr_mask'].state_dict() + new_sace_state[p] = save_state + + orig_state_dict['state'] = new_sace_state + + return orig_state_dict + + def load_state_dict(self, state_dict, strict=True): + # load the lr_mask from the state_dict + idx = 0 + for group in self.param_groups: + for p in group['params']: + self.initialize_state(p) + state = self.state[p] + m = state_dict['state'][idx]['lr_mask'] + sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale'] + state['lr_mask'] = Auto8bitTensor(sd_mask) + del state_dict['state'][idx]['lr_mask'] + idx += 1 + super().load_state_dict(state_dict) diff --git a/toolkit/optimizers/optimizer_utils.py b/toolkit/optimizers/optimizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..67991f21a5da53f22425dd8519eda3fc159bb93c --- /dev/null +++ b/toolkit/optimizers/optimizer_utils.py @@ -0,0 +1,256 @@ +import torch +from torch import Tensor +from typing import Optional +from optimum.quanto import QBytesTensor + + +def compute_scale_for_dtype(tensor, dtype): + """ + Compute appropriate scale for the given tensor and target dtype. 
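+    For int8 and float8 the scale is symmetric (abs-max divided by the largest representable
+    value); for uint8 it is the min-to-max range divided by 255.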
+ + Args: + tensor: Input tensor to be quantized + dtype: Target dtype for quantization + Returns: + Appropriate scale factor for the quantization + """ + if dtype == torch.int8: + abs_max = torch.max(torch.abs(tensor)) + return abs_max / 127.0 if abs_max > 0 else 1.0 + elif dtype == torch.uint8: + max_val = torch.max(tensor) + min_val = torch.min(tensor) + range_val = max_val - min_val + return range_val / 255.0 if range_val > 0 else 1.0 + elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + # For float8, we typically want to preserve the magnitude of the values + # while fitting within the representable range of the format + abs_max = torch.max(torch.abs(tensor)) + if dtype == torch.float8_e4m3fn: + # e4m3fn has range [-448, 448] with no infinities + max_representable = 448.0 + else: # torch.float8_e5m2 + # e5m2 has range [-57344, 57344] with infinities + max_representable = 57344.0 + + return abs_max / max_representable if abs_max > 0 else 1.0 + else: + raise ValueError(f"Unsupported dtype for quantization: {dtype}") + +def quantize_tensor(tensor, dtype): + """ + Quantize a floating-point tensor to the target dtype with appropriate scaling. + + Args: + tensor: Input tensor (float) + dtype: Target dtype for quantization + Returns: + quantized_data: Quantized tensor + scale: Scale factor used + """ + scale = compute_scale_for_dtype(tensor, dtype) + + if dtype == torch.int8: + quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype) + elif dtype == torch.uint8: + quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype) + elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + # For float8, we scale and then cast directly to the target type + # The casting operation will handle the appropriate rounding + scaled_tensor = tensor / scale + quantized_data = scaled_tensor.to(dtype) + else: + raise ValueError(f"Unsupported dtype for quantization: {dtype}") + + return quantized_data, scale + + +def update_parameter(target, result_float): + """ + Updates a parameter tensor, handling both regular torch.Tensor and QBytesTensor cases + with proper rescaling for quantized tensors. + + Args: + target: The parameter to update (either torch.Tensor or QBytesTensor) + result_float: The new values to assign (torch.Tensor) + """ + if isinstance(target, QBytesTensor): + # Get the target dtype from the existing quantized tensor + target_dtype = target._data.dtype + + # Handle device placement + device = target._data.device + result_float = result_float.to(device) + + # Compute new quantized values and scale + quantized_data, new_scale = quantize_tensor(result_float, target_dtype) + + # Update the internal tensors with newly computed values + target._data.copy_(quantized_data) + target._scale.copy_(new_scale) + else: + # Regular tensor update + target.copy_(result_float) + + +def get_format_params(dtype: torch.dtype) -> tuple[int, int]: + """ + Returns (mantissa_bits, total_bits) for each format. + mantissa_bits excludes the implicit leading 1. 
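+    copy_stochastic() uses the mantissa width to decide how many low-order float32 bits to
+    randomize before truncating to the target dtype.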
+ """ + if dtype == torch.float32: + return 23, 32 + elif dtype == torch.bfloat16: + return 7, 16 + elif dtype == torch.float16: + return 10, 16 + elif dtype == torch.float8_e4m3fn: + return 3, 8 + elif dtype == torch.float8_e5m2: + return 2, 8 + elif dtype == torch.int8: + return 0, 8 # Int8 doesn't have mantissa bits + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def copy_stochastic( + target: torch.Tensor, + source: torch.Tensor, + eps: Optional[float] = None +) -> None: + """ + Performs stochastic rounding from source tensor to target tensor. + + Args: + target: Destination tensor (determines the target format) + source: Source tensor (typically float32) + eps: Optional minimum value for stochastic rounding (for numerical stability) + """ + with torch.no_grad(): + # If target is float32, just copy directly + if target.dtype == torch.float32: + target.copy_(source) + return + + # Special handling for int8 + if target.dtype == torch.int8: + # Scale the source values to utilize the full int8 range + scaled = source * 127.0 # Scale to [-127, 127] + + # Add random noise for stochastic rounding + noise = torch.rand_like(scaled) - 0.5 + rounded = torch.round(scaled + noise) + + # Clamp to int8 range + clamped = torch.clamp(rounded, -127, 127) + target.copy_(clamped.to(torch.int8)) + return + + mantissa_bits, _ = get_format_params(target.dtype) + + # Convert source to int32 view + source_int = source.view(dtype=torch.int32) + + # Calculate number of bits to round + bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits + + # Create random integers for stochastic rounding + rand = torch.randint_like( + source, + dtype=torch.int32, + low=0, + high=(1 << bits_to_round), + ) + + # Add random values to the bits that will be rounded off + result = source_int.clone() + result.add_(rand) + + # Mask to keep only the bits we want + # Create mask with 1s in positions we want to keep + mask = (-1) << bits_to_round + result.bitwise_and_(mask) + + # Handle minimum value threshold if specified + if eps is not None: + eps_int = torch.tensor( + eps, dtype=torch.float32).view(dtype=torch.int32) + zero_mask = (result.abs() < eps_int) + result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int + + # Convert back to float32 view + result_float = result.view(dtype=torch.float32) + + # Special handling for float8 formats + if target.dtype == torch.float8_e4m3fn: + result_float.clamp_(-448.0, 448.0) + elif target.dtype == torch.float8_e5m2: + result_float.clamp_(-57344.0, 57344.0) + + # Copy the result to the target tensor + update_parameter(target, result_float) + # target.copy_(result_float) + del result, rand, source_int + + +class Auto8bitTensor: + def __init__(self, data: Tensor, *args, **kwargs): + if isinstance(data, dict): # Add constructor from state dict + self._load_from_state_dict(data) + else: + abs_max = data.abs().max().item() + scale = abs_max / 127.0 if abs_max > 0 else 1.0 + + self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8) + self.scale = scale + self.orig_dtype = data.dtype + + def dequantize(self) -> Tensor: + return self.quantized.to(dtype=torch.float32) * self.scale + + def to(self, *args, **kwargs): + # Handle the dtype argument whether it's positional or keyword + dtype = None + if args and isinstance(args[0], torch.dtype): + dtype = args[0] + args = args[1:] + elif 'dtype' in kwargs: + dtype = kwargs['dtype'] + del kwargs['dtype'] + + if dtype is not None: + # First dequantize then convert to requested dtype + return 
self.dequantize().to(dtype=dtype, *args, **kwargs) + + # If no dtype specified, just pass through to parent + return self.dequantize().to(*args, **kwargs) + + def state_dict(self): + """Returns a dictionary containing the current state of the tensor.""" + return { + 'quantized': self.quantized, + 'scale': self.scale, + 'orig_dtype': self.orig_dtype + } + + def _load_from_state_dict(self, state_dict): + """Loads the tensor state from a state dictionary.""" + self.quantized = state_dict['quantized'] + self.scale = state_dict['scale'] + self.orig_dtype = state_dict['orig_dtype'] + + def __str__(self): + return f"Auto8bitTensor({self.dequantize()})" + + +def stochastic_grad_accummulation(param): + if hasattr(param, "_accum_grad"): + grad_fp32 = param._accum_grad.clone().to(torch.float32) + grad_fp32.add_(param.grad.to(torch.float32)) + copy_stochastic(param._accum_grad, grad_fp32) + del grad_fp32 + del param.grad + else: + param._accum_grad = param.grad.clone() + del param.grad diff --git a/toolkit/optimizers/prodigy_8bit.py b/toolkit/optimizers/prodigy_8bit.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7f09149583da67d8f4fbaea6051b0b6694e467 --- /dev/null +++ b/toolkit/optimizers/prodigy_8bit.py @@ -0,0 +1,286 @@ +import math +import torch +import torch.distributed as dist +from torch.optim import Optimizer +from toolkit.optimizers.optimizer_utils import copy_stochastic, Auto8bitTensor, stochastic_grad_accummulation + + +class Prodigy8bit(Optimizer): + r""" + Implements Adam with Prodigy step-sizes. + Handles stochastic rounding for various precisions as well as stochastic gradient accumulation. + Stores state in 8bit for memory savings. + Leave LR set to 1 unless you encounter instability. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate. + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + beta3 (float): + coefficients for computing the Prodidy stepsize using running averages. + If set to None, uses the value of square root of beta2 (default: None). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + decouple (boolean): + Use AdamW style decoupled weight decay + use_bias_correction (boolean): + Turn on Adam's bias correction. Off by default. + safeguard_warmup (boolean): + Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. + d0 (float): + Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. + d_coef (float): + Coefficient in the expression for the estimate of d (default 1.0). + Values such as 0.5 and 2.0 typically work as well. + Changing this parameter is the preferred way to tune the method. + growth_rate (float): + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. Values like 1.02 give a kind of learning + rate warmup effect. + fsdp_in_use (bool): + If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. 
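+
+    Example (a minimal sketch; `model` and `dataloader` are placeholders for your own module
+    and data, not part of this toolkit):
+
+    ```python
+    optimizer = Prodigy8bit(model.parameters(), lr=1.0, weight_decay=0.01)
+    for batch in dataloader:
+        loss = model(batch)  # assumes the model returns a scalar loss
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+    ```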
+ """ + + def __init__(self, params, lr=1.0, + betas=(0.9, 0.999), beta3=None, + eps=1e-8, weight_decay=0, decouple=True, + use_bias_correction=False, safeguard_warmup=False, + d0=1e-6, d_coef=1.0, growth_rate=float('inf'), + fsdp_in_use=False): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1])) + + if decouple and weight_decay > 0: + print(f"Using decoupled weight decay") + + defaults = dict(lr=lr, betas=betas, beta3=beta3, + eps=eps, weight_decay=weight_decay, + d=d0, d0=d0, d_max=d0, + d_numerator=0.0, d_coef=d_coef, + k=0, growth_rate=growth_rate, + use_bias_correction=use_bias_correction, + decouple=decouple, safeguard_warmup=safeguard_warmup, + fsdp_in_use=fsdp_in_use) + self.d0 = d0 + super(Prodigy8bit, self).__init__(params, defaults) + + self.is_stochastic_rounding_accumulation = False + + # setup stochastic grad accum hooks + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and param.dtype != torch.float32: + self.is_stochastic_rounding_accumulation = True + param.register_post_accumulate_grad_hook( + stochastic_grad_accummulation + ) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step_hook(self): + if not self.is_stochastic_rounding_accumulation: + return + # copy over stochastically rounded grads + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and hasattr(param, "_accum_grad"): + param.grad = param._accum_grad + del param._accum_grad + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
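+
+        The first pass over the parameters updates the Prodigy `d` estimate (its numerator
+        and denominator); the second pass applies the Adam-style update with step size
+        `dlr = d * lr * bias_correction`.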
+ """ + # call pre step + self.step_hook() + loss = None + if closure is not None: + loss = closure() + + d_denom = 0.0 + + group = self.param_groups[0] + use_bias_correction = group['use_bias_correction'] + beta1, beta2 = group['betas'] + beta3 = group['beta3'] + if beta3 is None: + beta3 = math.sqrt(beta2) + k = group['k'] + + d = group['d'] + d_max = group['d_max'] + d_coef = group['d_coef'] + lr = max(group['lr'] for group in self.param_groups) + + if use_bias_correction: + bias_correction = ((1 - beta2**(k+1))**0.5) / (1 - beta1**(k+1)) + else: + bias_correction = 1 + + dlr = d*lr*bias_correction + + growth_rate = group['growth_rate'] + decouple = group['decouple'] + fsdp_in_use = group['fsdp_in_use'] + + d_numerator = group['d_numerator'] + d_numerator *= beta3 + + for group in self.param_groups: + decay = group['weight_decay'] + k = group['k'] + eps = group['eps'] + group_lr = group['lr'] + d0 = group['d0'] + safeguard_warmup = group['safeguard_warmup'] + + if group_lr not in [lr, 0.0]: + raise RuntimeError( + f"Setting different lr values in different parameter groups is only supported for values of 0") + + for p in group['params']: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + + grad = p.grad.data.to(torch.float32) + p_fp32 = p.clone().to(torch.float32) + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p_fp32.data, alpha=decay) + + state = self.state[p] + + # State initialization + if 'step' not in state: + state['step'] = 0 + state['s'] = Auto8bitTensor( + torch.zeros_like(p_fp32.data).detach()) + state['p0'] = Auto8bitTensor(p_fp32.detach().clone()) + # Exponential moving average of gradient values + state['exp_avg'] = Auto8bitTensor( + torch.zeros_like(p_fp32.data).detach()) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = Auto8bitTensor( + torch.zeros_like(p_fp32.data).detach()) + + exp_avg = state['exp_avg'].to(torch.float32) + exp_avg_sq = state['exp_avg_sq'].to(torch.float32) + + s = state['s'].to(torch.float32) + p0 = state['p0'].to(torch.float32) + + if group_lr > 0.0: + # we use d / d0 instead of just d to avoid getting values that are too small + d_numerator += (d / d0) * dlr * torch.dot(grad.flatten(), + (p0.data - p_fp32.data).flatten()).item() + + # Adam EMA updates + exp_avg.mul_(beta1).add_(grad, alpha=d * (1-beta1)) + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=d * d * (1-beta2)) + + if safeguard_warmup: + s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) + else: + s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr)) + d_denom += s.abs().sum().item() + + # update state with stochastic rounding + state['exp_avg'] = Auto8bitTensor(exp_avg) + state['exp_avg_sq'] = Auto8bitTensor(exp_avg_sq) + state['s'] = Auto8bitTensor(s) + state['p0'] = Auto8bitTensor(p0) + + d_hat = d + + # if we have not done any progres, return + # if we have any gradients available, will have d_denom > 0 (unless \|g\|=0) + if d_denom == 0: + return loss + + if lr > 0.0: + if fsdp_in_use: + dist_tensor = torch.zeros(2).cuda() + dist_tensor[0] = d_numerator + dist_tensor[1] = d_denom + dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) + global_d_numerator = dist_tensor[0] + global_d_denom = dist_tensor[1] + else: + global_d_numerator = d_numerator + global_d_denom = d_denom + + d_hat = d_coef * global_d_numerator / global_d_denom + if d == group['d0']: + d = max(d, d_hat) + d_max = max(d_max, d_hat) + d = min(d_max, d * growth_rate) + + for group in self.param_groups: + 
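+            # broadcast the updated Prodigy statistics to every param group before the
+            # second pass that applies the parameter updates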
group['d_numerator'] = global_d_numerator + group['d_denom'] = global_d_denom + group['d'] = d + group['d_max'] = d_max + group['d_hat'] = d_hat + + decay = group['weight_decay'] + k = group['k'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.to(torch.float32) + p_fp32 = p.clone().to(torch.float32) + + state = self.state[p] + + exp_avg = state['exp_avg'].to(torch.float32) + exp_avg_sq = state['exp_avg_sq'].to(torch.float32) + + state['step'] += 1 + + denom = exp_avg_sq.sqrt().add_(d * eps) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple: + p_fp32.data.add_(p_fp32.data, alpha=-decay * dlr) + + # Take step + p_fp32.data.addcdiv_(exp_avg, denom, value=-dlr) + # apply stochastic rounding + copy_stochastic(p.data, p_fp32.data) + + group['k'] = k + 1 + + return loss diff --git a/toolkit/orig_configs/sd_xl_refiner.yaml b/toolkit/orig_configs/sd_xl_refiner.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cab5fe283d77bf86e0f29e99f3ed0d3c7d9c752f --- /dev/null +++ b/toolkit/orig_configs/sd_xl_refiner.yaml @@ -0,0 +1,91 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2560 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 384 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 4 + context_dim: [1280, 1280, 1280, 1280] # 1280 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn and vector cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + legacy: False + freeze: True + layer: penultimate + always_return_pooled: True + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: aesthetic_score + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by one + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/toolkit/paths.py b/toolkit/paths.py new file mode 100644 
index 0000000000000000000000000000000000000000..b926c82f13d36790b9b17de7355b3a0d6e1abcbd --- /dev/null +++ b/toolkit/paths.py @@ -0,0 +1,22 @@ +import os + +TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config') +SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts") +REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories") +KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps") +ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs") +DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs") + +# check if ENV variable is set +if 'MODELS_PATH' in os.environ: + MODELS_PATH = os.environ['MODELS_PATH'] +else: + MODELS_PATH = os.path.join(TOOLKIT_ROOT, "models") + + +def get_path(path): + # we allow absolute paths, but if it is not absolute, we assume it is relative to the toolkit root + if not os.path.isabs(path): + path = os.path.join(TOOLKIT_ROOT, path) + return path diff --git a/toolkit/photomaker.py b/toolkit/photomaker.py new file mode 100644 index 0000000000000000000000000000000000000000..8037969507854129cf342d8b3fae7a6d1ff7581e --- /dev/null +++ b/toolkit/photomaker.py @@ -0,0 +1,144 @@ +# Merge image encoder and fuse module to create an ID Encoder +# send multiple ID images, we can directly obtain the updated text encoder containing a stacked ID embedding + +import torch +import torch.nn as nn +from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection +from transformers.models.clip.configuration_clip import CLIPVisionConfig +from transformers import PretrainedConfig + +VISION_CONFIG_DICT = { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768 +} + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class FuseModule(nn.Module): + def __init__(self, embed_dim): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True) + self.layer_norm = nn.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + ) -> torch.Tensor: + # id_embeds shape: [b, max_num_inputs, 1, 2048] + id_embeds = id_embeds.to(prompt_embeds.dtype) + num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case + batch_size, max_num_inputs = id_embeds.shape[:2] + # seq_length: 77 + seq_length = prompt_embeds.shape[1] + # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] + flat_id_embeds = id_embeds.view( + -1, id_embeds.shape[-2], id_embeds.shape[-1] + ) + # valid_id_mask [b*max_num_inputs] + valid_id_mask = ( + 
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] + < num_inputs[:, None] + ) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) + class_tokens_mask = class_tokens_mask.view(-1) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) + # slice out the image token embeddings + image_token_embeds = prompt_embeds[class_tokens_mask] + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + return updated_prompt_embeds + +class PhotoMakerIDEncoder(CLIPVisionModelWithProjection): + def __init__(self, config=None, *model_args, **model_kwargs): + if config is None: + config = CLIPVisionConfig(**VISION_CONFIG_DICT) + super().__init__(config, *model_args, **model_kwargs) + self.visual_projection_2 = nn.Linear(1024, 1280, bias=False) + self.fuse_module = FuseModule(2048) + + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[1] + id_embeds = self.visual_projection(shared_id_embeds) + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + updated_prompt_embeds = self.fuse_module( + prompt_embeds, id_embeds, class_tokens_mask) + + return updated_prompt_embeds + + +class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection): + def __init__(self, config=None, *model_args, **model_kwargs): + if config is None: + config = CLIPVisionConfig(**VISION_CONFIG_DICT) + super().__init__(config, *model_args, **model_kwargs) + self.visual_projection_2 = nn.Linear(1024, 1280, bias=False) + + def forward(self, id_pixel_values, do_projection2=True, output_full=False): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + # last_hidden_state, 1, 257, 1024 + vision_output = self.vision_model(id_pixel_values, output_hidden_states=True) + shared_id_embeds = vision_output[1] + id_embeds = self.visual_projection(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + + if do_projection2: + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + + if output_full: + return id_embeds, vision_output + return id_embeds + + + +if __name__ == "__main__": + PhotoMakerIDEncoder() \ No newline at end of file diff --git a/toolkit/photomaker_pipeline.py b/toolkit/photomaker_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d6437b648e91e5d4e70abbcf0995d76dc1b00f81 --- /dev/null +++ b/toolkit/photomaker_pipeline.py @@ -0,0 +1,491 @@ +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from collections import OrderedDict +import os +import PIL +import numpy as np + +import torch +from torchvision import transforms as T + +from safetensors import safe_open +from huggingface_hub.utils import validate_hf_hub_args +from 
transformers import CLIPImageProcessor, CLIPTokenizer +from diffusers import StableDiffusionXLPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.utils import ( + _get_model_file, + is_transformers_available, + logging, +) + +from .photomaker import PhotoMakerIDEncoder + +PipelineImageInput = Union[ + PIL.Image.Image, + torch.FloatTensor, + List[PIL.Image.Image], + List[torch.FloatTensor], +] + + +class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): + @validate_hf_hub_args + def load_photomaker_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str, + subfolder: str = '', + trigger_word: str = 'img', + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + weight_name (`str`): + The weight name NOT the path to the weight. + + subfolder (`str`, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + trigger_word (`str`, *optional*, defaults to `"img"`): + The trigger word is used to identify the position of class word in the text prompt, + and it is recommended not to set it as a common word. + This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation. + """ + + # Load the main state dict first. 
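+        # (a PhotoMaker checkpoint bundles two groups of weights: `id_encoder` and `lora_weights`)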
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"id_encoder": {}, "lora_weights": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("id_encoder."): + state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key) + elif key.startswith("lora_weights."): + state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if keys != ["id_encoder", "lora_weights"]: + raise ValueError("Required keys are (`id_encoder` and `lora_weights`) missing from the state dict.") + + self.trigger_word = trigger_word + # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet + print(f"Loading PhotoMaker components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...") + id_encoder = PhotoMakerIDEncoder() + id_encoder.load_state_dict(state_dict["id_encoder"], strict=True) + id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype) + self.id_encoder = id_encoder + self.id_image_processor = CLIPImageProcessor() + + # load lora into models + print(f"Loading PhotoMaker components [2] lora_weights from [{pretrained_model_name_or_path_or_dict}]") + self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker") + + # Add trigger word token + if self.tokenizer is not None: + self.tokenizer.add_tokens([self.trigger_word], special_tokens=True) + + self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True) + + def encode_prompt_with_trigger_word( + self, + prompt: str, + prompt_2: Optional[str] = None, + num_id_images: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + class_tokens_mask: Optional[torch.LongTensor] = None, + ): + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Find the token id of the trigger word + image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word) + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_embeds_list = [] + prompts 
= [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + input_ids = tokenizer.encode(prompt) # TODO: batch encode + clean_index = 0 + clean_input_ids = [] + class_token_index = [] + # Find out the corrresponding class word token based on the newly added trigger word token + for i, token_id in enumerate(input_ids): + if token_id == image_token_id: + class_token_index.append(clean_index - 1) + else: + clean_input_ids.append(token_id) + clean_index += 1 + + if len(class_token_index) != 1: + raise ValueError( + f"PhotoMaker currently does not support multiple trigger words in a single prompt.\ + Trigger word: {self.trigger_word}, Prompt: {prompt}." + ) + class_token_index = class_token_index[0] + + # Expand the class word token and corresponding mask + class_token = clean_input_ids[class_token_index] + clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \ + clean_input_ids[class_token_index + 1:] + + # Truncation or padding + max_len = tokenizer.model_max_length + if len(clean_input_ids) > max_len: + clean_input_ids = clean_input_ids[:max_len] + else: + clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * ( + max_len - len(clean_input_ids) + ) + + class_tokens_mask = [True if class_token_index <= i < class_token_index + num_id_images else False \ + for i in range(len(clean_input_ids))] + + clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0) + class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0) + + prompt_embeds = text_encoder( + clean_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case + + return prompt_embeds, pooled_prompt_embeds, class_tokens_mask + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + # Added parameters (for PhotoMaker) + input_id_images: PipelineImageInput = None, + 
start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future + class_tokens_mask: Optional[torch.LongTensor] = None, + prompt_embeds_text_only: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + Only the parameters introduced by PhotoMaker are discussed here. + For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py + + Args: + input_id_images (`PipelineImageInput`, *optional*): + Input ID Image to work with PhotoMaker. + class_tokens_mask (`torch.LongTensor`, *optional*): + Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word. + prompt_embeds_text_only (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + # + if prompt_embeds is not None and class_tokens_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`." + ) + # check the input id images + if input_id_images is None: + raise ValueError( + "Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline." + ) + if not isinstance(input_id_images, list): + input_id_images = [input_id_images] + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + assert do_classifier_free_guidance + + # 3. 
Encode input prompt + num_id_images = len(input_id_images) + + ( + prompt_embeds, + pooled_prompt_embeds, + class_tokens_mask, + ) = self.encode_prompt_with_trigger_word( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_id_images=num_id_images, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + class_tokens_mask=class_tokens_mask, + ) + + # 4. Encode input prompt without the trigger word for delayed conditioning + prompt_text_only = prompt.replace(" " + self.trigger_word, "") # sensitive to white space + ( + prompt_embeds_text_only, + negative_prompt_embeds, + pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt_text_only, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds_text_only, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds_text_only, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 5. Prepare the input ID images + dtype = next(self.id_encoder.parameters()).dtype + if not isinstance(input_id_images[0], torch.Tensor): + id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values + + id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts + + # 6. Get the update text embedding with the stacked ID embedding + prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + # 7. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 8. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Prepare added time ids & embeddings + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if i <= start_merge_step: + current_prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds_text_only], dim=0 + ) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0) + else: + current_prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=current_prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + # if self.watermark is not None: + # image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..c0509ee188f34e19a07546aa0dfd0606ff438426 --- /dev/null +++ b/toolkit/pipelines.py @@ -0,0 +1,1421 @@ +import importlib +import inspect +from typing import Union, List, Optional, Dict, Any, Tuple, Callable + +import numpy as np +import torch +from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline +from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.stable_diffusion 
import StableDiffusionPipelineOutput +# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from diffusers.utils import is_torch_xla_available +from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser +from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler +from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline): + + def __init__( + self, + vae: 'AutoencoderKL', + text_encoder: 'CLIPTextModel', + text_encoder_2: 'CLIPTextModelWithProjection', + tokenizer: 'CLIPTokenizer', + tokenizer_2: 'CLIPTokenizer', + unet: 'UNet2DConditionModel', + scheduler: 'KarrasDiffusionSchedulers', + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + raise NotImplementedError("This pipeline is not implemented yet") + # self.sampler = None + # scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + # model = ModelWrapper(unet, scheduler.alphas_cumprod) + # if scheduler.config.prediction_type == "v_prediction": + # self.k_diffusion_model = CompVisVDenoiser(model) + # else: + # self.k_diffusion_model = CompVisDenoiser(model) + + def set_scheduler(self, scheduler_type: str): + library = importlib.import_module("k_diffusion") + sampling = getattr(library, "sampling") + self.sampler = getattr(sampling, scheduler_type) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + use_karras_sigmas: bool = False, + ): + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. 
Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 5. 
Prepare sigmas + if use_karras_sigmas: + sigma_min: float = self.k_diffusion_model.sigmas[0].item() + sigma_max: float = self.k_diffusion_model.sigmas[-1].item() + sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) + sigmas = sigmas.to(device) + else: + sigmas = self.scheduler.sigmas + sigmas = sigmas.to(prompt_embeds.dtype) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + latents = latents * sigmas[0] + self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) + self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) + + # 7. Define model function + def model_fn(x, t): + latent_model_input = torch.cat([x] * 2) + t = torch.cat([t] * 2) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + # noise_pred = self.unet( + # latent_model_input, + # t, + # encoder_hidden_states=prompt_embeds, + # cross_attention_kwargs=cross_attention_kwargs, + # added_cond_kwargs=added_cond_kwargs, + # return_dict=False, + # )[0] + + noise_pred = self.k_diffusion_model( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False,)[0] + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + return noise_pred + + + # 8. Run k-diffusion solver + sampler_kwargs = {} + # should work without it + noise_sampler_seed = None + + + if "noise_sampler" in inspect.signature(self.sampler).parameters: + min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) + sampler_kwargs["noise_sampler"] = noise_sampler + + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + +class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): + + def predict_noise( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: 
Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + crops_coords_top_left: Tuple[int, int] = (0, 0), + timestep: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+ latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. 
Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # if not predict_noise: + # # call parent + # return super().__call__( + # prompt=prompt, + # prompt_2=prompt_2, + # height=height, + # width=width, + # num_inference_steps=num_inference_steps, + # denoising_end=denoising_end, + # guidance_scale=guidance_scale, + # negative_prompt=negative_prompt, + # negative_prompt_2=negative_prompt_2, + # num_images_per_prompt=num_images_per_prompt, + # eta=eta, + # generator=generator, + # latents=latents, + # prompt_embeds=prompt_embeds, + # negative_prompt_embeds=negative_prompt_embeds, + # pooled_prompt_embeds=pooled_prompt_embeds, + # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + # output_type=output_type, + # return_dict=return_dict, + # callback=callback, + # callback_steps=callback_steps, + # cross_attention_kwargs=cross_attention_kwargs, + # guidance_rescale=guidance_rescale, + # original_size=original_size, + # crops_coords_top_left=crops_coords_top_left, + # target_size=target_size, + # ) + + # 0. Default height and width to unet + height = self.default_sample_size * self.vae_scale_factor + width = self.default_sample_size * self.vae_scale_factor + + original_size = (height, width) + target_size = (height, width) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ).to(device) # TODO DOES NOT CAST ORIGINALLY + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + return noise_pred + + def enable_model_cpu_offload(self, gpu_id=0): + print('Called cpu offload', gpu_id) + # fuck off + pass + + +class CustomStableDiffusionPipeline(StableDiffusionPipeline): + + # replace the call so it matches SDXL call so we can use the same code and also stop early + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + # 0. 
Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 7.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + # some of the inputs are to keep it compatible with sdx + def predict_noise( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + crops_coords_top_left: Tuple[int, int] = (0, 0), + timestep: Optional[int] = None, + ): + + # 0. Default height and width to unet + height = self.unet.config.sample_size * self.vae_scale_factor + width = self.unet.config.sample_size * self.vae_scale_factor + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. 
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + return noise_pred + + +class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline): + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + denoising_start: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. 
+ Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
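+ # i.e. the denoising loop below combines the unconditional and text-conditioned predictions as noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)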
+ do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 8.2 Determine denoising_start + denoising_start_index = 0 + if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1: + discrete_timestep_start = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps))) + + + with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar: + for i, t in enumerate(timesteps, start=denoising_start_index): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + + + +# TODO this is rough. Need to properly stack unconditional +class FluxWithCFGPipeline(FluxPipeline): + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + # bypass the guidance embedding if there is one + bypass_flux_guidance(self.transformer) + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
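        # (calculate_shift above picks the flow-matching shift `mu` by linear interpolation on
        #  the packed image sequence length, so larger resolutions get a stronger timestep
        #  shift. Roughly, as a sketch of the diffusers helper rather than its exact source:
        #      m  = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
        #      mu = base_shift + m * (image_seq_len - base_image_seq_len)
        #  retrieve_timesteps then builds the sigma schedule from that mu.)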
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.tensor([guidance_scale], device=device) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + noise_pred_text = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # todo combine these + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + restore_flux_guidance(self.transformer) + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file diff --git a/toolkit/progress_bar.py b/toolkit/progress_bar.py new file mode 100644 index 0000000000000000000000000000000000000000..e42f8086a7d29016beea66b09e8c0fdc574c5422 --- /dev/null +++ b/toolkit/progress_bar.py @@ -0,0 +1,25 @@ +from tqdm import tqdm +import time + + +class ToolkitProgressBar(tqdm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.paused = False + self.last_time = self._time() + + def pause(self): + if not self.paused: + self.paused = True + self.last_time = self._time() + + def unpause(self): + if self.paused: + self.paused = False + cur_t = self._time() + self.start_t += cur_t - self.last_time + self.last_print_t = cur_t + + def update(self, *args, **kwargs): + if not self.paused: + 
super().update(*args, **kwargs) diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a145841cc75e9fd96d9278d4eb8d78db262c7c8d --- /dev/null +++ b/toolkit/prompt_utils.py @@ -0,0 +1,561 @@ +import os +from typing import Optional, TYPE_CHECKING, List, Union, Tuple + +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import random + +from toolkit.train_tools import get_torch_dtype +import itertools + +if TYPE_CHECKING: + from toolkit.config_modules import SliderTargetConfig + + +class ACTION_TYPES_SLIDER: + ERASE_NEGATIVE = 0 + ENHANCE_NEGATIVE = 1 + + +class PromptEmbeds: + # text_embeds: torch.Tensor + # pooled_embeds: Union[torch.Tensor, None] + # attention_mask: Union[torch.Tensor, None] + + def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None: + if isinstance(args, list) or isinstance(args, tuple): + # xl + self.text_embeds = args[0] + self.pooled_embeds = args[1] + else: + # sdv1.x, sdv2.x + self.text_embeds = args + self.pooled_embeds = None + + self.attention_mask = attention_mask + + def to(self, *args, **kwargs): + self.text_embeds = self.text_embeds.to(*args, **kwargs) + if self.pooled_embeds is not None: + self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs) + if self.attention_mask is not None: + self.attention_mask = self.attention_mask.to(*args, **kwargs) + return self + + def detach(self): + new_embeds = self.clone() + new_embeds.text_embeds = new_embeds.text_embeds.detach() + if new_embeds.pooled_embeds is not None: + new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach() + if new_embeds.attention_mask is not None: + new_embeds.attention_mask = new_embeds.attention_mask.detach() + return new_embeds + + def clone(self): + if self.pooled_embeds is not None: + prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()]) + else: + prompt_embeds = PromptEmbeds(self.text_embeds.clone()) + + if self.attention_mask is not None: + prompt_embeds.attention_mask = self.attention_mask.clone() + return prompt_embeds + + +class EncodedPromptPair: + def __init__( + self, + target_class, + target_class_with_neutral, + positive_target, + positive_target_with_neutral, + negative_target, + negative_target_with_neutral, + neutral, + empty_prompt, + both_targets, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + action_list=None, + multiplier=1.0, + multiplier_list=None, + weight=1.0, + target: 'SliderTargetConfig' = None, + ): + self.target_class: PromptEmbeds = target_class + self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral + self.positive_target: PromptEmbeds = positive_target + self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral + self.negative_target: PromptEmbeds = negative_target + self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral + self.neutral: PromptEmbeds = neutral + self.empty_prompt: PromptEmbeds = empty_prompt + self.both_targets: PromptEmbeds = both_targets + self.multiplier: float = multiplier + self.target: 'SliderTargetConfig' = target + if multiplier_list is not None: + self.multiplier_list: list[float] = multiplier_list + else: + self.multiplier_list: list[float] = [multiplier] + self.action: int = action + if action_list is not None: + self.action_list: list[int] = action_list + else: + self.action_list: list[int] = [action] + self.weight: float = weight + + # simulate torch to for 
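    # (PromptEmbeds above is a thin wrapper: pass a bare tensor for SD 1.x/2.x encoders,
    #  or a [text_embeds, pooled_embeds] pair for SDXL-style encoders. Illustrative only,
    #  with hypothetical shapes:
    #      PromptEmbeds(torch.randn(1, 77, 768))                           # pooled stays None
    #      PromptEmbeds([torch.randn(1, 77, 2048), torch.randn(1, 1280)])  # SDXL-style pair
    #  .to(), .clone() and .detach() are applied to every tensor the wrapper holds.)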
tensors + def to(self, *args, **kwargs): + self.target_class = self.target_class.to(*args, **kwargs) + self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs) + self.positive_target = self.positive_target.to(*args, **kwargs) + self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs) + self.negative_target = self.negative_target.to(*args, **kwargs) + self.negative_target_with_neutral = self.negative_target_with_neutral.to(*args, **kwargs) + self.neutral = self.neutral.to(*args, **kwargs) + self.empty_prompt = self.empty_prompt.to(*args, **kwargs) + self.both_targets = self.both_targets.to(*args, **kwargs) + return self + + def detach(self): + self.target_class = self.target_class.detach() + self.target_class_with_neutral = self.target_class_with_neutral.detach() + self.positive_target = self.positive_target.detach() + self.positive_target_with_neutral = self.positive_target_with_neutral.detach() + self.negative_target = self.negative_target.detach() + self.negative_target_with_neutral = self.negative_target_with_neutral.detach() + self.neutral = self.neutral.detach() + self.empty_prompt = self.empty_prompt.detach() + self.both_targets = self.both_targets.detach() + return self + + +def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]): + text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0) + pooled_embeds = None + if prompt_embeds[0].pooled_embeds is not None: + pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0) + return PromptEmbeds([text_embeds, pooled_embeds]) + + +def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]): + weight = prompt_pairs[0].weight + target_class = concat_prompt_embeds([p.target_class for p in prompt_pairs]) + target_class_with_neutral = concat_prompt_embeds([p.target_class_with_neutral for p in prompt_pairs]) + positive_target = concat_prompt_embeds([p.positive_target for p in prompt_pairs]) + positive_target_with_neutral = concat_prompt_embeds([p.positive_target_with_neutral for p in prompt_pairs]) + negative_target = concat_prompt_embeds([p.negative_target for p in prompt_pairs]) + negative_target_with_neutral = concat_prompt_embeds([p.negative_target_with_neutral for p in prompt_pairs]) + neutral = concat_prompt_embeds([p.neutral for p in prompt_pairs]) + empty_prompt = concat_prompt_embeds([p.empty_prompt for p in prompt_pairs]) + both_targets = concat_prompt_embeds([p.both_targets for p in prompt_pairs]) + # combine all the lists + action_list = [] + multiplier_list = [] + weight_list = [] + for p in prompt_pairs: + action_list += p.action_list + multiplier_list += p.multiplier_list + return EncodedPromptPair( + target_class=target_class, + target_class_with_neutral=target_class_with_neutral, + positive_target=positive_target, + positive_target_with_neutral=positive_target_with_neutral, + negative_target=negative_target, + negative_target_with_neutral=negative_target_with_neutral, + neutral=neutral, + empty_prompt=empty_prompt, + both_targets=both_targets, + action_list=action_list, + multiplier_list=multiplier_list, + weight=weight, + target=prompt_pairs[0].target + ) + + +def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]: + if num_parts is None: + # use batch size + num_parts = concatenated.text_embeds.shape[0] + text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0) + + if concatenated.pooled_embeds is not None: + pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, 
num_parts, dim=0) + else: + pooled_embeds_splits = [None] * num_parts + + prompt_embeds_list = [ + PromptEmbeds([text, pooled]) + for text, pooled in zip(text_embeds_splits, pooled_embeds_splits) + ] + + return prompt_embeds_list + + +def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]: + target_class_splits = split_prompt_embeds(concatenated.target_class, num_embeds) + target_class_with_neutral_splits = split_prompt_embeds(concatenated.target_class_with_neutral, num_embeds) + positive_target_splits = split_prompt_embeds(concatenated.positive_target, num_embeds) + positive_target_with_neutral_splits = split_prompt_embeds(concatenated.positive_target_with_neutral, num_embeds) + negative_target_splits = split_prompt_embeds(concatenated.negative_target, num_embeds) + negative_target_with_neutral_splits = split_prompt_embeds(concatenated.negative_target_with_neutral, num_embeds) + neutral_splits = split_prompt_embeds(concatenated.neutral, num_embeds) + empty_prompt_splits = split_prompt_embeds(concatenated.empty_prompt, num_embeds) + both_targets_splits = split_prompt_embeds(concatenated.both_targets, num_embeds) + + prompt_pairs = [] + for i in range(len(target_class_splits)): + action_list_split = concatenated.action_list[i::len(target_class_splits)] + multiplier_list_split = concatenated.multiplier_list[i::len(target_class_splits)] + + prompt_pair = EncodedPromptPair( + target_class=target_class_splits[i], + target_class_with_neutral=target_class_with_neutral_splits[i], + positive_target=positive_target_splits[i], + positive_target_with_neutral=positive_target_with_neutral_splits[i], + negative_target=negative_target_splits[i], + negative_target_with_neutral=negative_target_with_neutral_splits[i], + neutral=neutral_splits[i], + empty_prompt=empty_prompt_splits[i], + both_targets=both_targets_splits[i], + action_list=action_list_split, + multiplier_list=multiplier_list_split, + weight=concatenated.weight, + target=concatenated.target + ) + prompt_pairs.append(prompt_pair) + + return prompt_pairs + + +class PromptEmbedsCache: + prompts: dict[str, PromptEmbeds] = {} + + def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class EncodedAnchor: + def __init__( + self, + prompt, + neg_prompt, + multiplier=1.0, + multiplier_list=None + ): + self.prompt = prompt + self.neg_prompt = neg_prompt + self.multiplier = multiplier + + if multiplier_list is not None: + self.multiplier_list: list[float] = multiplier_list + else: + self.multiplier_list: list[float] = [multiplier] + + def to(self, *args, **kwargs): + self.prompt = self.prompt.to(*args, **kwargs) + self.neg_prompt = self.neg_prompt.to(*args, **kwargs) + return self + + +def concat_anchors(anchors: list[EncodedAnchor]): + prompt = concat_prompt_embeds([a.prompt for a in anchors]) + neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors]) + return EncodedAnchor( + prompt=prompt, + neg_prompt=neg_prompt, + multiplier_list=[a.multiplier for a in anchors] + ) + + +def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]: + prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors) + neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors) + multiplier_list_splits = torch.chunk(torch.tensor(concatenated.multiplier_list), 
num_anchors) + + anchors = [] + for prompt, neg_prompt, multiplier in zip(prompt_splits, neg_prompt_splits, multiplier_list_splits): + anchor = EncodedAnchor( + prompt=prompt, + neg_prompt=neg_prompt, + multiplier=multiplier.tolist() + ) + anchors.append(anchor) + + return anchors + + +def get_permutations(s, max_permutations=8): + # Split the string by comma + phrases = [phrase.strip() for phrase in s.split(',')] + + # remove empty strings + phrases = [phrase for phrase in phrases if len(phrase) > 0] + # shuffle the list + random.shuffle(phrases) + + # Get all permutations + permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)]) + + # Convert the tuples back to comma separated strings + return [', '.join(permutation) for permutation in permutations] + + +def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']: + from toolkit.config_modules import SliderTargetConfig + pos_permutations = get_permutations(target.positive, max_permutations=max_permutations) + neg_permutations = get_permutations(target.negative, max_permutations=max_permutations) + + permutations = [] + for pos, neg in itertools.product(pos_permutations, neg_permutations): + permutations.append( + SliderTargetConfig( + target_class=target.target_class, + positive=pos, + negative=neg, + multiplier=target.multiplier, + weight=target.weight + ) + ) + + # shuffle the list + random.shuffle(permutations) + + if len(permutations) > max_permutations: + permutations = permutations[:max_permutations] + + return permutations + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + + +@torch.no_grad() +def encode_prompts_to_cache( + prompt_list: list[str], + sd: "StableDiffusion", + cache: Optional[PromptEmbedsCache] = None, + prompt_tensor_file: Optional[str] = None, +) -> PromptEmbedsCache: + # TODO: add support for larger prompts + if cache is None: + cache = PromptEmbedsCache() + + if prompt_tensor_file is not None: + # check to see if it exists + if os.path.exists(prompt_tensor_file): + # load it. + print(f"Loading prompt tensors from {prompt_tensor_file}") + prompt_tensors = load_file(prompt_tensor_file, device='cpu') + # add them to the cache + for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False): + if prompt_txt.startswith("te:"): + prompt = prompt_txt[3:] + # text_embeds + text_embeds = prompt_tensor + pooled_embeds = None + # find pool embeds + if f"pe:{prompt}" in prompt_tensors: + pooled_embeds = prompt_tensors[f"pe:{prompt}"] + + # make it + prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds]) + cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32) + + if len(cache.prompts) == 0: + print("Prompt tensors not found. Encoding prompts..") + empty_prompt = "" + # encode empty_prompt + cache[empty_prompt] = sd.encode_prompt(empty_prompt) + + for p in tqdm(prompt_list, desc="Encoding prompts", leave=False): + # build the cache + if cache[p] is None: + cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16) + + # should we shard? 
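    # (The prompt-tensor file handled in this function is a flat safetensors dict keyed by
    #  "te:<prompt>" for text embeds and "pe:<prompt>" for pooled embeds, saved in fp16 on
    #  CPU; the loader above reverses that mapping. Purely illustrative inspection, with a
    #  hypothetical path:
    #      from safetensors.torch import load_file
    #      cached  = load_file("prompt_cache.safetensors", device="cpu")
    #      prompts = [k[3:] for k in cached if k.startswith("te:")]
    #  )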
It can get large + if prompt_tensor_file: + print(f"Saving prompt tensors to {prompt_tensor_file}") + state_dict = {} + for prompt_txt, prompt_embeds in cache.prompts.items(): + state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to( + "cpu", dtype=get_torch_dtype('fp16') + ) + if prompt_embeds.pooled_embeds is not None: + state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to( + "cpu", + dtype=get_torch_dtype('fp16') + ) + save_file(state_dict, prompt_tensor_file) + + return cache + + +@torch.no_grad() +def build_prompt_pair_batch_from_cache( + cache: PromptEmbedsCache, + target: 'SliderTargetConfig', + neutral: Optional[str] = '', +) -> list[EncodedPromptPair]: + erase_negative = len(target.positive.strip()) == 0 + enhance_positive = len(target.negative.strip()) == 0 + + both = not erase_negative and not enhance_positive + + prompt_pair_batch = [] + + if both or erase_negative: + # print("Encoding erase negative") + prompt_pair_batch += [ + # erase standard + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.positive}"], + positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + negative_target=cache[f"{target.negative}"], + negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + weight=target.weight, + target=target + ), + ] + if both or enhance_positive: + # print("Encoding enhance positive") + prompt_pair_batch += [ + # enhance standard, swap pos neg + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.negative}"], + positive_target_with_neutral=cache[f"{target.negative} {neutral}"], + negative_target=cache[f"{target.positive}"], + negative_target_with_neutral=cache[f"{target.positive} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + weight=target.weight, + target=target + ), + ] + if both or enhance_positive: + # print("Encoding erase positive (inverse)") + prompt_pair_batch += [ + # erase inverted + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.negative}"], + positive_target_with_neutral=cache[f"{target.negative} {neutral}"], + negative_target=cache[f"{target.positive}"], + negative_target_with_neutral=cache[f"{target.positive} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + multiplier=target.multiplier * -1.0, + weight=target.weight, + target=target + ), + ] + if both or erase_negative: + # print("Encoding enhance negative (inverse)") + prompt_pair_batch += [ + # enhance inverted + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.positive}"], + positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + negative_target=cache[f"{target.negative}"], + 
negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + both_targets=cache[f"{target.positive} {target.negative}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + empty_prompt=cache[""], + multiplier=target.multiplier * -1.0, + weight=target.weight, + target=target + ), + ] + + return prompt_pair_batch + + +def build_latent_image_batch_for_prompt_pair( + pos_latent, + neg_latent, + prompt_pair: EncodedPromptPair, + prompt_chunk_size +): + erase_negative = len(prompt_pair.target.positive.strip()) == 0 + enhance_positive = len(prompt_pair.target.negative.strip()) == 0 + both = not erase_negative and not enhance_positive + + prompt_pair_chunks = split_prompt_pairs(prompt_pair, prompt_chunk_size) + if both and len(prompt_pair_chunks) != 4: + raise Exception("Invalid prompt pair chunks") + if (erase_negative or enhance_positive) and len(prompt_pair_chunks) != 2: + raise Exception("Invalid prompt pair chunks") + + latent_list = [] + + if both or erase_negative: + latent_list.append(pos_latent) + if both or enhance_positive: + latent_list.append(pos_latent) + if both or enhance_positive: + latent_list.append(neg_latent) + if both or erase_negative: + latent_list.append(neg_latent) + + return torch.cat(latent_list, dim=0) + + +def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True): + if trigger is None: + # process as empty string to remove any [trigger] tokens + trigger = '' + output_prompt = prompt + default_replacements = ["[name]", "[trigger]"] + + replace_with = trigger + if to_replace_list is None: + to_replace_list = default_replacements + else: + to_replace_list += default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + if trigger.strip() != "": + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances == 0 and add_if_not_present: + # add it to the beginning of the prompt + output_prompt = replace_with + " " + output_prompt + + # if num_instances > 1: + # print( + # f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. 
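    # (Illustrative behaviour of this helper, values made up:
    #      inject_trigger_into_prompt("a photo of [trigger]", "sks person")
    #          -> "a photo of sks person"
    #      inject_trigger_into_prompt("a photo of a dog", "sks person")
    #          -> "sks person a photo of a dog"   # prepended since add_if_not_present=True
    #      inject_trigger_into_prompt("a photo of [trigger]", None)
    #          -> "a photo of "                   # [trigger]/[name] tokens are stripped
    #  )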
This may cause issues.") + + return output_prompt diff --git a/toolkit/reference_adapter.py b/toolkit/reference_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..d00dfb72974d03917723a4ef54caee7f32dbcdd1 --- /dev/null +++ b/toolkit/reference_adapter.py @@ -0,0 +1,410 @@ +import math + +import torch +import sys + +from PIL import Image +from torch.nn import Parameter +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from toolkit.basic import adain +from toolkit.paths import REPOS_ROOT +from toolkit.saving import load_ip_adapter_model +from toolkit.train_tools import get_torch_dtype + +sys.path.append(REPOS_ROOT) +from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict +from collections import OrderedDict +from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \ + AttnProcessor2_0 +from ipadapter.ip_adapter.ip_adapter import ImageProjModel +from ipadapter.ip_adapter.resampler import Resampler +from toolkit.config_modules import AdapterConfig +from toolkit.prompt_utils import PromptEmbeds +import weakref + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from diffusers import ( + EulerDiscreteScheduler, + DDPMScheduler, +) + +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection +) +from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel + +from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification + +from transformers import ViTFeatureExtractor, ViTForImageClassification + +import torch.nn.functional as F +import torch.nn as nn + + +class ReferenceAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.ref_net = nn.Linear(hidden_size, hidden_size) + self.blend = nn.Parameter(torch.zeros(hidden_size)) + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self._memory = None + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + if self.adapter_ref().is_active: + if self.adapter_ref().reference_mode == "write": + # write_mode + memory_ref = self.ref_net(hidden_states) + self._memory = memory_ref + elif self.adapter_ref().reference_mode == "read": + # read_mode + if self._memory is None: + print("Warning: no memory to read from") + else: + + saved_hidden_states = self._memory + try: + new_hidden_states = saved_hidden_states + blend = self.blend + # expand the blend buyt keep dim 0 the same (batch) + while blend.ndim < 
new_hidden_states.ndim: + blend = blend.unsqueeze(0) + # expand batch + blend = torch.cat([blend] * new_hidden_states.shape[0], dim=0) + hidden_states = blend * new_hidden_states + (1 - blend) * hidden_states + except Exception as e: + raise Exception(f"Error blending: {e}") + + return hidden_states + + +class ReferenceAdapter(torch.nn.Module): + + def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'): + super().__init__() + self.config = adapter_config + self.sd_ref: weakref.ref = weakref.ref(sd) + self.device = self.sd_ref().unet.device + self.reference_mode = "read" + self.current_scale = 1.0 + self.is_active = True + self._reference_images = None + self._reference_latents = None + self.has_memory = False + + self.noise_scheduler: Union[DDPMScheduler, EulerDiscreteScheduler] = None + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + for name in sd.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor2_0() + else: + # layer_name = name.split(".processor")[0] + # weights = { + # "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + # "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + # } + + attn_procs[name] = ReferenceAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.config.num_tokens, + adapter=self + ) + # attn_procs[name].load_state_dict(weights) + sd.unet.set_attn_processor(attn_procs) + adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + sd.adapter = self + self.unet_ref: weakref.ref = weakref.ref(sd.unet) + self.adapter_modules = adapter_modules + # load the weights if we have some + if self.config.name_or_path: + loaded_state_dict = load_ip_adapter_model( + self.config.name_or_path, + device='cpu', + dtype=sd.torch_dtype + ) + self.load_state_dict(loaded_state_dict) + + self.set_scale(1.0) + self.attach() + self.to(self.device, self.sd_ref().torch_dtype) + + # if self.config.train_image_encoder: + # self.image_encoder.train() + # self.image_encoder.requires_grad_(True) + + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + # self.image_encoder.to(*args, **kwargs) + # self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + return self + + def load_reference_adapter(self, state_dict: Union[OrderedDict, dict]): + reference_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + reference_layers.load_state_dict(state_dict["reference_adapter"]) + + # def load_state_dict(self, state_dict: Union[OrderedDict, dict]): + # self.load_ip_adapter(state_dict) + + def state_dict(self) -> OrderedDict: + state_dict = OrderedDict() + state_dict["reference_adapter"] = self.adapter_modules.state_dict() + return state_dict + + def get_scale(self): + return self.current_scale + + def set_reference_images(self, reference_images: 
Optional[torch.Tensor]): + self._reference_images = reference_images.clone().detach() + self._reference_latents = None + self.clear_memory() + + def set_blank_reference_images(self, batch_size): + self._reference_images = torch.zeros((batch_size, 3, 512, 512), device=self.device, dtype=self.sd_ref().torch_dtype) + self._reference_latents = torch.zeros((batch_size, 4, 64, 64), device=self.device, dtype=self.sd_ref().torch_dtype) + self.clear_memory() + + + def set_scale(self, scale): + self.current_scale = scale + for attn_processor in self.sd_ref().unet.attn_processors.values(): + if isinstance(attn_processor, ReferenceAttnProcessor2_0): + attn_processor.scale = scale + + + def attach(self): + unet = self.sd_ref().unet + self._original_unet_forward = unet.forward + unet.forward = lambda *args, **kwargs: self.unet_forward(*args, **kwargs) + if self.sd_ref().network is not None: + # set network to not merge in + self.sd_ref().network.can_merge_in = False + + def unet_forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): + skip = False + if self._reference_images is None and self._reference_latents is None: + skip = True + if not self.is_active: + skip = True + + if self.has_memory: + skip = True + + if not skip: + if self.sd_ref().network is not None: + self.sd_ref().network.is_active = True + if self.sd_ref().network.is_merged_in: + raise ValueError("network is merged in, but we are not supposed to be merged in") + # send it through our forward first + self.forward(sample, timestep, encoder_hidden_states, *args, **kwargs) + + if self.sd_ref().network is not None: + self.sd_ref().network.is_active = False + + # Send it through the original unet forward + return self._original_unet_forward(sample, timestep, encoder_hidden_states, args, **kwargs) + + + # use drop for prompt dropout, or negatives + def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): + if not self.noise_scheduler: + raise ValueError("noise scheduler not set") + if not self.is_active or (self._reference_images is None and self._reference_latents is None): + raise ValueError("reference adapter not active or no reference images set") + # todo may need to handle cfg? + self.reference_mode = "write" + + if self._reference_latents is None: + self._reference_latents = self.sd_ref().encode_images(self._reference_images.to( + self.device, self.sd_ref().torch_dtype + )).detach() + # create a sample from our reference images + reference_latents = self._reference_latents.clone().detach().to(self.device, self.sd_ref().torch_dtype) + # if our num of samples are half of incoming, we are doing cfg. 
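        # (CFG note for the branch below: the sample batch arrives as [unconditional,
        #  conditional], so the reference latents are zero-filled for the unconditional half
        #  rather than duplicated, leaving the negative branch without reference conditioning.
        #  The write pass that follows stores a learned projection (ref_net) of each attention
        #  layer's hidden states in ReferenceAttnProcessor2_0._memory; the later read pass
        #  blends that memory back in with the learned per-channel `blend` parameter.)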
Zero out the first half (unconditional) + if reference_latents.shape[0] * 2 == sample.shape[0]: + # we are doing cfg + # Unconditional goes first + reference_latents = torch.cat([torch.zeros_like(reference_latents), reference_latents], dim=0).detach() + + # resize it so reference_latents will fit inside sample in the center + width_scale = sample.shape[2] / reference_latents.shape[2] + height_scale = sample.shape[3] / reference_latents.shape[3] + scale = min(width_scale, height_scale) + # resize the reference latents + + mode = "bilinear" if scale > 1.0 else "bicubic" + + reference_latents = F.interpolate( + reference_latents, + size=(int(reference_latents.shape[2] * scale), int(reference_latents.shape[3] * scale)), + mode=mode, + align_corners=False + ) + + # add 0 padding if needed + width_pad = (sample.shape[2] - reference_latents.shape[2]) / 2 + height_pad = (sample.shape[3] - reference_latents.shape[3]) / 2 + reference_latents = F.pad( + reference_latents, + (math.floor(width_pad), math.floor(width_pad), math.ceil(height_pad), math.ceil(height_pad)), + mode="constant", + value=0 + ) + + # resize again just to make sure it is exact same size + reference_latents = F.interpolate( + reference_latents, + size=(sample.shape[2], sample.shape[3]), + mode="bicubic", + align_corners=False + ) + + # todo maybe add same noise to the sample? For now we will send it through with no noise + # sample_imgs = self.noise_scheduler.add_noise(sample_imgs, timestep) + self._original_unet_forward(reference_latents, timestep, encoder_hidden_states, *args, **kwargs) + self.reference_mode = "read" + self.has_memory = True + return None + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + for attn_processor in self.adapter_modules: + yield from attn_processor.parameters(recurse) + # yield from self.image_proj_model.parameters(recurse) + # if self.config.train_image_encoder: + # yield from self.image_encoder.parameters(recurse) + # if self.config.train_image_encoder: + # yield from self.image_encoder.parameters(recurse) + # self.image_encoder.train() + # else: + # for attn_processor in self.adapter_modules: + # yield from attn_processor.parameters(recurse) + # yield from self.image_proj_model.parameters(recurse) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + strict = False + # self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) + self.adapter_modules.load_state_dict(state_dict["reference_adapter"], strict=strict) + + def enable_gradient_checkpointing(self): + self.image_encoder.gradient_checkpointing = True + + def clear_memory(self): + for attn_processor in self.adapter_modules: + if isinstance(attn_processor, ReferenceAttnProcessor2_0): + attn_processor._memory = None + self.has_memory = False diff --git a/toolkit/resampler.py b/toolkit/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9ace5a3a18d78e6f5b712dd587aaa45827247dc6 --- /dev/null +++ b/toolkit/resampler.py @@ -0,0 +1,160 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py +# and https://github.com/tencent-ailab/IP-Adapter/blob/9fc189e3fb389cc2b60a7d0c0850e083a716ea6e/ip_adapter/resampler.py + +import math + +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + 
return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + max_seq_len: int = 257, # CLIP tokens + CLS token + apply_pos_emb: bool = False, + num_latents_mean_pooled: int = 0, + # number of latents derived from mean pooled representation of the sequence + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x = x + pos_emb + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) + + for attn, ff in self.layers: + latents = 
attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +def masked_mean(t, *, dim, mask=None): + if mask is None: + return t.mean(dim=dim) + + denom = mask.sum(dim=dim, keepdim=True) + mask = rearrange(mask, "b n -> b n 1") + masked_t = t.masked_fill(~mask, 0.0) + + return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) diff --git a/toolkit/sampler.py b/toolkit/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b0311b3b4c6b788bcf18c78df2c86a134465c0 --- /dev/null +++ b/toolkit/sampler.py @@ -0,0 +1,164 @@ +import copy +import math + +from diffusers import ( + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + LCMScheduler, + FlowMatchEulerDiscreteScheduler, +) + +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler + +from k_diffusion.external import CompVisDenoiser + +from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +sd_config = { + "_class_name": "EulerAncestralDiscreteScheduler", + "_diffusers_version": "0.24.0.dev0", + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": False, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + # "skip_prk_steps": False, # for training + "skip_prk_steps": True, + # "steps_offset": 1, + "steps_offset": 0, + # "timestep_spacing": "trailing", # for training + "timestep_spacing": "leading", + "trained_betas": None +} + +pixart_config = { + "_class_name": "DPMSolverMultistepScheduler", + "_diffusers_version": "0.22.0.dev0", + "algorithm_type": "dpmsolver++", + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "dynamic_thresholding_ratio": 0.995, + "euler_at_final": False, + # "lambda_min_clipped": -Infinity, + "lambda_min_clipped": -math.inf, + "lower_order_final": True, + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "solver_order": 2, + "solver_type": "midpoint", + "steps_offset": 0, + "thresholding": False, + "timestep_spacing": "linspace", + "trained_betas": None, + "use_karras_sigmas": False, + "use_lu_lambdas": False, + "variance_type": None +} + + +def get_sampler( + sampler: str, + kwargs: dict = None, + arch: str = "sd" +): + sched_init_args = {} + if kwargs is not None: + sched_init_args.update(kwargs) + + config_to_use = copy.deepcopy(sd_config) if arch == "sd" else copy.deepcopy(pixart_config) + + if sampler.startswith("k_"): + sched_init_args["use_karras_sigmas"] = True + + if sampler == "ddim": + scheduler_cls = DDIMScheduler + elif sampler == "ddpm": # ddpm is not supported ? 
+ scheduler_cls = DDPMScheduler + elif sampler == "pndm": + scheduler_cls = PNDMScheduler + elif sampler == "lms" or sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif sampler == "euler" or sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif sampler == "euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif sampler == "dpmsolver" or sampler == "dpmsolver++" or sampler == "k_dpmsolver" or sampler == "k_dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = sampler.replace("k_", "") + elif sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif sampler == "dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif sampler == "dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + elif sampler == "lcm": + scheduler_cls = LCMScheduler + elif sampler == "custom_lcm": + scheduler_cls = CustomLCMScheduler + elif sampler == "flowmatch": + scheduler_cls = CustomFlowMatchEulerDiscreteScheduler + config_to_use = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.29.0.dev0", + "num_train_timesteps": 1000, + "shift": 3.0 + } + else: + raise ValueError(f"Sampler {sampler} not supported") + + + config = copy.deepcopy(config_to_use) + config.update(sched_init_args) + + scheduler = scheduler_cls.from_config(config) + + return scheduler + + +# testing +if __name__ == "__main__": + from diffusers import DiffusionPipeline + + from diffusers import StableDiffusionKDiffusionPipeline + import torch + import os + + inference_steps = 25 + + pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base") + pipe = pipe.to("cuda") + + k_diffusion_model = CompVisDenoiser(model) + + pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion") + pipe = pipe.to("cuda") + + prompt = "an astronaut riding a horse on mars" + pipe.set_scheduler("sample_heun") + generator = torch.Generator(device="cuda").manual_seed(seed) + image = pipe(prompt, generator=generator, num_inference_steps=20).images[0] + + image.save("./astronaut_heun_k_diffusion.png") diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..440eb4fa1220fab83a0d4ce3cdae928b873031ff --- /dev/null +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -0,0 +1,110 @@ +import math +from typing import Union + +from diffusers import FlowMatchEulerDiscreteScheduler +import torch + + +class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_noise_sigma = 1.0 + + with torch.no_grad(): + # create weights for timesteps + num_timesteps = 1000 + # Bell-Shaped Mean-Normalized Timestep Weighting + # bsmntw? 
need a better name + + x = torch.arange(num_timesteps, dtype=torch.float32) + y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2) + + # Shift minimum to 0 + y_shifted = y - y.min() + + # Scale to make mean 1 + bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum()) + + # only do half bell + hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum()) + + # flatten second half to max + hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max() + + # Create linear timesteps from 1000 to 0 + timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + + self.linear_timesteps = timesteps + self.linear_timesteps_weights = bsmntw_weighing + self.linear_timesteps_weights2 = hbsmntw_weighing + pass + + def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor: + # Get the indices of the timesteps + step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + + # Get the weights for the timesteps + if v2: + weights = self.linear_timesteps_weights2[step_indices].flatten() + else: + weights = self.linear_timesteps_weights[step_indices].flatten() + + return weights + + def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: + sigmas = self.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = self.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + + return sigma + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 + ## Add noise according to flow matching. + ## zt = (1 - texp) * x + texp * z1 + + # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # timestep needs to be in [0, 1], we store them in [0, 1000] + # noisy_sample = (1 - timestep) * latent + timestep * noise + t_01 = (timesteps / 1000).to(original_samples.device) + noisy_model_input = (1 - t_01) * original_samples + t_01 * noise + + # n_dim = original_samples.ndim + # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) + # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + return noisy_model_input + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + return sample + + def set_train_timesteps(self, num_timesteps, device, linear=False): + if linear: + timesteps = torch.linspace(1000, 0, num_timesteps, device=device) + self.timesteps = timesteps + return timesteps + else: + # distribute them closer to center. 
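            # (The add_noise above is the straight-line flow-matching interpolation
            #  x_t = (1 - t/1000) * x_0 + (t/1000) * noise, so t=0 returns the clean latent and
            #  t=1000 returns pure noise; the bell-shaped weights are normalized to average 1.0
            #  over the 1000 steps so weighted losses keep their overall scale.
            #  Illustrative check only:
            #      sched = CustomFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
            #      x, z  = torch.zeros(1, 4, 8, 8), torch.ones(1, 4, 8, 8)
            #      assert torch.allclose(sched.add_noise(x, z, torch.tensor([500.0])), 0.5 * z)
            #  )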
Inference distributes them as a bias toward first + # Generate values from 0 to 1 + t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) + + # Scale and reverse the values to go from 1000 to 0 + timesteps = ((1 - t) * 1000) + + # Sort the timesteps in descending order + timesteps, _ = torch.sort(timesteps, descending=True) + + self.timesteps = timesteps.to(device=device) + + return timesteps diff --git a/toolkit/samplers/custom_lcm_scheduler.py b/toolkit/samplers/custom_lcm_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..132052af74186b2597060d66c764c8b4ed841378 --- /dev/null +++ b/toolkit/samplers/custom_lcm_scheduler.py @@ -0,0 +1,553 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class LCMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + denoised: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor: + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class CustomLCMScheduler(SchedulerMixin, ConfigMixin): + """ + `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config + attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be + accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving + functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + original_inference_steps (`int`, *optional*, defaults to 50): + The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we + will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. 
+ clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + timestep_scaling (`float`, defaults to 10.0): + The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions + `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation + error at the default of `10.0` is already pretty small). + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + original_inference_steps: int = 50, + clip_sample: bool = False, + clip_sample_range: float = 1.0, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + timestep_scaling: float = 10.0, + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
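+ # i.e. with the defaults below (beta_start=0.00085, beta_end=0.012, 1000 steps), the betas are
+ # linearly spaced in sqrt space and then squared: torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2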
+ self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + self.original_inference_steps = 50 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + self.train_timesteps = 1000 + + self._step_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(index_candidates) > 1: + step_index = index_candidates[1] + else: + step_index = index_candidates[0] + + self._step_index = step_index.item() + + @property + def step_index(self): + return self._step_index + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + strength: int = 1.0, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + original_inference_steps (`int`, *optional*): + The original number of inference steps, which will be used to generate a linearly-spaced timestep + schedule (which is different from the standard `diffusers` implementation). We will then take + `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as + our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute. + """ + + original_inference_steps = self.original_inference_steps + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + original_steps = ( + original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps + ) + + if original_steps > self.config.num_train_timesteps: + raise ValueError( + f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + if num_inference_steps > original_steps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:" + f" {original_steps} because the final timestep schedule will be a subset of the" + f" `original_inference_steps`-sized initial timestep schedule." + ) + + # LCM Timesteps Setting + # The skipping step parameter k from the paper. + k = self.config.num_train_timesteps // original_steps + # LCM Training/Distillation Steps Schedule + # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts). 
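+ # For illustration, with the defaults in this file (num_train_timesteps=1000, original_inference_steps=50,
+ # strength=1.0): k = 1000 // 50 = 20, so the original grid is [19, 39, ..., 999]; with num_inference_steps=4
+ # the evenly spaced index picks below yield the final schedule [999, 679, 359, 19].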
+ lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1 + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + + if skipping_step < 1: + raise ValueError( + f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}." + ) + + # LCM Inference Steps Schedule + lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from lcm_origin_timesteps. + inference_indices = np.linspace(0, len(lcm_origin_timesteps) - 1, num=num_inference_steps) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = lcm_origin_timesteps[inference_indices] + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long) + + self._step_index = None + + def get_scalings_for_boundary_condition_discrete(self, timestep): + self.sigma_data = 0.5 # Default: 0.5 + scaled_timestep = timestep * self.config.timestep_scaling + + c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5 + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[LCMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. 
Compute the predicted original sample x_0 based on the model parameterization + if self.config.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + elif self.config.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.config.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 5. Clip or threshold "predicted x_0" + if self.config.thresholding: + predicted_original_sample = self._threshold_sample(predicted_original_sample) + elif self.config.clip_sample: + predicted_original_sample = predicted_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 6. Denoise model output using boundary conditions + denoised = c_out * predicted_original_sample + c_skip * sample + + # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + if self.step_index != self.num_inference_steps - 1: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype + ) + prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + else: + prev_sample = denoised + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample, denoised) + + return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + 
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/toolkit/saving.py b/toolkit/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..7abc7d5058d347b06cdcfcd4b452223e7920b346 --- /dev/null +++ b/toolkit/saving.py @@ -0,0 +1,330 @@ +import json +import os +from collections import OrderedDict +from typing import TYPE_CHECKING, Literal, Optional, Union + +import torch +from safetensors.torch import load_file, save_file + +from toolkit.train_tools import get_torch_dtype +from toolkit.paths import KEYMAPS_ROOT + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + + +def get_slices_from_string(s: str) -> tuple: + slice_strings = s.split(',') + slices = [eval(f"slice({component.strip()})") for component in slice_strings] + return tuple(slices) + + +def convert_state_dict_to_ldm_with_mapping( + diffusers_state_dict: 'OrderedDict', + mapping_path: str, + base_path: Union[str, None] = None, + device: str = 'cpu', + dtype: torch.dtype = torch.float32 +) -> 'OrderedDict': + converted_state_dict = OrderedDict() + + # load mapping + with open(mapping_path, 'r') as f: + mapping = json.load(f, object_pairs_hook=OrderedDict) + + # keep track of keys not matched + ldm_matched_keys = [] + diffusers_matched_keys = [] + + ldm_diffusers_keymap = mapping['ldm_diffusers_keymap'] + ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map'] + ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map'] + + # load base if it exists + # the base just has come keys like timing ids and stuff diffusers doesn't have or they don't match + if base_path is not None: + converted_state_dict = load_file(base_path, device) + # convert to the right dtype + for key in converted_state_dict: + converted_state_dict[key] = converted_state_dict[key].to(device, dtype=dtype) + + # process operators first + for ldm_key in ldm_diffusers_operator_map: + # if the key cat is in the ldm key, we need to process it + if 'cat' in ldm_diffusers_operator_map[ldm_key]: + cat_list = [] + for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']: + cat_list.append(diffusers_state_dict[diffusers_key].detach()) + converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype) + diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['cat']) + ldm_matched_keys.append(ldm_key) + if 'slice' in ldm_diffusers_operator_map[ldm_key]: + tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]] + slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]] + converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device, + dtype=dtype) + diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['slice']) + ldm_matched_keys.append(ldm_key) + + # process the rest of the keys + for ldm_key in ldm_diffusers_keymap: + # if the key is in the ldm key, we need to process it + if ldm_diffusers_keymap[ldm_key] in diffusers_state_dict: + tensor = diffusers_state_dict[ldm_diffusers_keymap[ldm_key]].detach().to(device, dtype=dtype) + # see if we need to reshape + if 
ldm_key in ldm_diffusers_shape_map: + tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0]) + converted_state_dict[ldm_key] = tensor + diffusers_matched_keys.append(ldm_diffusers_keymap[ldm_key]) + ldm_matched_keys.append(ldm_key) + + # see if any are missing from know mapping + mapped_diffusers_keys = list(ldm_diffusers_keymap.values()) + mapped_ldm_keys = list(ldm_diffusers_keymap.keys()) + + missing_diffusers_keys = [x for x in mapped_diffusers_keys if x not in diffusers_matched_keys] + missing_ldm_keys = [x for x in mapped_ldm_keys if x not in ldm_matched_keys] + + if len(missing_diffusers_keys) > 0: + print(f"WARNING!!!! Missing {len(missing_diffusers_keys)} diffusers keys") + print(missing_diffusers_keys) + if len(missing_ldm_keys) > 0: + print(f"WARNING!!!! Missing {len(missing_ldm_keys)} ldm keys") + print(missing_ldm_keys) + + return converted_state_dict + + +def get_ldm_state_dict_from_diffusers( + state_dict: 'OrderedDict', + sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega', 'sdxl_refiner'] = '2', + device='cpu', + dtype=get_torch_dtype('fp32'), +): + if sd_version == '1': + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1.json') + elif sd_version == '2': + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2.json') + elif sd_version == 'sdxl': + # load our base + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json') + elif sd_version == 'ssd': + # load our base + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json') + elif sd_version == 'vega': + # load our base + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega.json') + elif sd_version == 'sdxl_refiner': + # load our base + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json') + else: + raise ValueError(f"Invalid sd_version {sd_version}") + + # convert the state dict + return convert_state_dict_to_ldm_with_mapping( + state_dict, + mapping_path, + base_path, + device=device, + dtype=dtype + ) + + +def save_ldm_model_from_diffusers( + sd: 'StableDiffusion', + output_file: str, + meta: 'OrderedDict', + save_dtype=get_torch_dtype('fp16'), + sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2' +): + converted_state_dict = get_ldm_state_dict_from_diffusers( + sd.state_dict(), + sd_version, + device='cpu', + dtype=save_dtype + ) + + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) + + +def save_lora_from_diffusers( + lora_state_dict: 'OrderedDict', + output_file: str, + meta: 'OrderedDict', + save_dtype=get_torch_dtype('fp16'), + sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2' +): + converted_state_dict = OrderedDict() + # only handle sxdxl for now + if sd_version != 'sdxl' and sd_version != 'ssd' and sd_version != 'vega': + raise ValueError(f"Invalid sd_version {sd_version}") + for key, value in lora_state_dict.items(): + # todo verify if this works with ssd + # test encoders share keys for some 
reason + if key.startswith('lora_te'): + converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype) + else: + converted_key = key + + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) + + +def save_t2i_from_diffusers( + t2i_state_dict: 'OrderedDict', + output_file: str, + meta: 'OrderedDict', + dtype=get_torch_dtype('fp16'), +): + # todo: test compatibility with non diffusers + converted_state_dict = OrderedDict() + for key, value in t2i_state_dict.items(): + converted_state_dict[key] = value.detach().to('cpu', dtype=dtype) + + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) + + +def load_t2i_model( + path_to_file, + device: Union[str] = 'cpu', + dtype: torch.dtype = torch.float32 +): + raw_state_dict = load_file(path_to_file, device) + converted_state_dict = OrderedDict() + for key, value in raw_state_dict.items(): + # todo see if we need to convert dict + converted_state_dict[key] = value.detach().to(device, dtype=dtype) + return converted_state_dict + + + + +def save_ip_adapter_from_diffusers( + combined_state_dict: 'OrderedDict', + output_file: str, + meta: 'OrderedDict', + dtype=get_torch_dtype('fp16'), + direct_save: bool = False +): + # todo: test compatibility with non diffusers + + converted_state_dict = OrderedDict() + for module_name, state_dict in combined_state_dict.items(): + if direct_save: + converted_state_dict[module_name] = state_dict.detach().to('cpu', dtype=dtype) + else: + for key, value in state_dict.items(): + converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype) + + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) + + +def load_ip_adapter_model( + path_to_file, + device: Union[str] = 'cpu', + dtype: torch.dtype = torch.float32, + direct_load: bool = False +): + # check if it is safetensors or checkpoint + if path_to_file.endswith('.safetensors'): + raw_state_dict = load_file(path_to_file, device) + combined_state_dict = OrderedDict() + if direct_load: + return raw_state_dict + for combo_key, value in raw_state_dict.items(): + key_split = combo_key.split('.') + module_name = key_split.pop(0) + if module_name not in combined_state_dict: + combined_state_dict[module_name] = OrderedDict() + combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype) + return combined_state_dict + else: + return torch.load(path_to_file, map_location=device) + +def load_custom_adapter_model( + path_to_file, + device: Union[str] = 'cpu', + dtype: torch.dtype = torch.float32 +): + # check if it is safetensors or checkpoint + if path_to_file.endswith('.safetensors'): + raw_state_dict = load_file(path_to_file, device) + combined_state_dict = OrderedDict() + device = device if isinstance(device, torch.device) else torch.device(device) + dtype = dtype if isinstance(dtype, torch.dtype) else get_torch_dtype(dtype) + for combo_key, value in raw_state_dict.items(): + key_split = combo_key.split('.') + module_name = key_split.pop(0) + if module_name not in combined_state_dict: + combined_state_dict[module_name] = OrderedDict() + combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype) + return combined_state_dict + else: + return torch.load(path_to_file, map_location=device)
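A minimal usage sketch of the combined-state-dict convention handled by the adapter save/load helpers above. The module and tensor names are illustrative placeholders (they mimic, but are not taken from, a real adapter checkpoint), and the output path is arbitrary:

    from collections import OrderedDict
    import torch
    from toolkit.saving import save_ip_adapter_from_diffusers, load_ip_adapter_model

    # nested per-module state dicts are flattened to "module_name.param_name" keys on save
    combined = OrderedDict(
        image_proj=OrderedDict(weight=torch.zeros(4, 4)),
        ip_adapter=OrderedDict(to_k_ip=torch.zeros(4, 4)),
    )
    save_ip_adapter_from_diffusers(combined, '/tmp/adapter_example.safetensors', meta=OrderedDict())

    # loading splits each key on its first '.' and rebuilds the nested structure
    restored = load_ip_adapter_model('/tmp/adapter_example.safetensors', device='cpu')
    assert 'weight' in restored['image_proj'] and 'to_k_ip' in restored['ip_adapter']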
+ + +def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict': + lora_keymap = OrderedDict() + + # see if we have dual text encoders " a key that starts with conditioner.embedders.1 + has_dual_text_encoders = False + for key in model_keymap: + if key.startswith('conditioner.embedders.1'): + has_dual_text_encoders = True + break + # map through the keys and values + for key, value in model_keymap.items(): + # ignore bias weights + if key.endswith('bias'): + continue + if key.endswith('.weight'): + # remove the .weight + key = key[:-7] + if value.endswith(".weight"): + # remove the .weight + value = value[:-7] + + # unet for all + key = key.replace('model.diffusion_model', 'lora_unet') + if value.startswith('unet'): + value = f"lora_{value}" + + # text encoder + if has_dual_text_encoders: + key = key.replace('conditioner.embedders.0', 'lora_te1') + key = key.replace('conditioner.embedders.1', 'lora_te2') + if value.startswith('te0') or value.startswith('te1'): + value = f"lora_{value}" + value.replace('lora_te1', 'lora_te2') + value.replace('lora_te0', 'lora_te1') + + key = key.replace('cond_stage_model.transformer', 'lora_te') + + if value.startswith('te_'): + value = f"lora_{value}" + + # replace periods with underscores + key = key.replace('.', '_') + value = value.replace('.', '_') + + # add all the weights + lora_keymap[f"{key}.lora_down.weight"] = f"{value}.lora_down.weight" + lora_keymap[f"{key}.lora_down.bias"] = f"{value}.lora_down.bias" + lora_keymap[f"{key}.lora_up.weight"] = f"{value}.lora_up.weight" + lora_keymap[f"{key}.lora_up.bias"] = f"{value}.lora_up.bias" + lora_keymap[f"{key}.alpha"] = f"{value}.alpha" + + return lora_keymap diff --git a/toolkit/scheduler.py b/toolkit/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f8f61aeb8f63b12ee8f8f2800385c11ec3b7bf --- /dev/null +++ b/toolkit/scheduler.py @@ -0,0 +1,57 @@ +import torch +from typing import Optional +from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup + + +def get_lr_scheduler( + name: Optional[str], + optimizer: torch.optim.Optimizer, + **kwargs, +): + if name == "cosine": + if 'total_iters' in kwargs: + kwargs['T_max'] = kwargs.pop('total_iters') + return torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, **kwargs + ) + elif name == "cosine_with_restarts": + if 'total_iters' in kwargs: + kwargs['T_0'] = kwargs.pop('total_iters') + return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, **kwargs + ) + elif name == "step": + + return torch.optim.lr_scheduler.StepLR( + optimizer, **kwargs + ) + elif name == "constant": + if 'factor' not in kwargs: + kwargs['factor'] = 1.0 + + return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs) + elif name == "linear": + + return torch.optim.lr_scheduler.LinearLR( + optimizer, **kwargs + ) + elif name == 'constant_with_warmup': + # see if num_warmup_steps is in kwargs + if 'num_warmup_steps' not in kwargs: + print(f"WARNING: num_warmup_steps not in kwargs. 
Using default value of 1000") + kwargs['num_warmup_steps'] = 1000 + del kwargs['total_iters'] + return get_constant_schedule_with_warmup(optimizer, **kwargs) + else: + # try to use a diffusers scheduler + print(f"Trying to use diffusers scheduler {name}") + try: + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **kwargs) + except Exception as e: + print(e) + pass + raise ValueError( + "Scheduler must be cosine, cosine_with_restarts, step, linear or constant" + ) diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py new file mode 100644 index 0000000000000000000000000000000000000000..1eeecc323fefb7b06fdaf30ff9f80f5399b0ce09 --- /dev/null +++ b/toolkit/sd_device_states_presets.py @@ -0,0 +1,107 @@ +from typing import Union + +import torch +import copy + +empty_preset = { + 'vae': { + 'training': False, + 'device': 'cpu', + }, + 'unet': { + 'training': False, + 'requires_grad': False, + 'device': 'cpu', + }, + 'text_encoder': { + 'training': False, + 'requires_grad': False, + 'device': 'cpu', + }, + 'adapter': { + 'training': False, + 'requires_grad': False, + 'device': 'cpu', + }, + 'refiner_unet': { + 'training': False, + 'requires_grad': False, + 'device': 'cpu', + }, +} + + +def get_train_sd_device_state_preset( + device: Union[str, torch.device], + train_unet: bool = False, + train_text_encoder: bool = False, + cached_latents: bool = False, + train_lora: bool = False, + train_adapter: bool = False, + train_embedding: bool = False, + train_decorator: bool = False, + train_refiner: bool = False, + unload_text_encoder: bool = False, + require_grads: bool = True, +): + preset = copy.deepcopy(empty_preset) + if not cached_latents: + preset['vae']['device'] = device + + if train_unet: + preset['unet']['training'] = True + preset['unet']['requires_grad'] = require_grads + preset['unet']['device'] = device + else: + preset['unet']['device'] = device + + if train_text_encoder: + preset['text_encoder']['training'] = True + preset['text_encoder']['requires_grad'] = require_grads + preset['text_encoder']['device'] = device + else: + preset['text_encoder']['device'] = device + + if train_embedding: + preset['text_encoder']['training'] = True + preset['text_encoder']['requires_grad'] = require_grads + preset['text_encoder']['training'] = True + preset['unet']['training'] = True + + if train_refiner: + preset['refiner_unet']['training'] = True + preset['refiner_unet']['requires_grad'] = require_grads + preset['refiner_unet']['device'] = device + # if not training unet, move that to cpu + if not train_unet: + preset['unet']['device'] = 'cpu' + + if train_lora: + # preset['text_encoder']['requires_grad'] = False + preset['unet']['requires_grad'] = False + if train_refiner: + preset['refiner_unet']['requires_grad'] = False + + if train_adapter: + preset['adapter']['requires_grad'] = require_grads + preset['adapter']['training'] = True + preset['adapter']['device'] = device + preset['unet']['training'] = True + preset['unet']['requires_grad'] = False + preset['unet']['device'] = device + preset['text_encoder']['device'] = device + + if train_decorator: + preset['text_encoder']['training'] = False + preset['text_encoder']['requires_grad'] = False + preset['text_encoder']['device'] = device + preset['unet']['training'] = True + preset['unet']['requires_grad'] = False + preset['unet']['device'] = device + + if unload_text_encoder: + preset['text_encoder']['training'] = False + 
preset['text_encoder']['requires_grad'] = False + preset['text_encoder']['device'] = 'cpu' + + return preset diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py new file mode 100644 index 0000000000000000000000000000000000000000..23439fbc310d5153528e99114f31ca5c2988de0a --- /dev/null +++ b/toolkit/stable_diffusion_model.py @@ -0,0 +1,2754 @@ +import copy +import gc +import json +import random +import shutil +import typing +from typing import Union, List, Literal, Iterator +import sys +import os +from collections import OrderedDict +import copy +import yaml +from PIL import Image +from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \ + ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from safetensors.torch import save_file, load_file +from torch import autocast +from torch.nn import Parameter +from torch.utils.checkpoint import checkpoint +from tqdm import tqdm +from torchvision.transforms import Resize, transforms + +from toolkit.assistant_lora import load_assistant_lora_from_path +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.ip_adapter import IPAdapter +from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ + convert_vae_state_dict, load_vae +from toolkit import train_tools +from toolkit.config_modules import ModelConfig, GenerateImageConfig +from toolkit.metadata import get_meta_for_safetensors +from toolkit.models.decorator import Decorator +from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT +from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.sampler import get_sampler +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers +from toolkit.sd_device_states_presets import empty_preset +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +from einops import rearrange, repeat +import torch +from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ + StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ + StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \ + StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ + StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \ + StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \ + FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel +import diffusers +from diffusers import \ + AutoencoderKL, \ + UNet2DConditionModel +from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline +from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + +from toolkit.paths 
import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT +from huggingface_hub import hf_hub_download +from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance + +from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# tell it to shut up +diffusers.logging.set_verbosity(diffusers.logging.ERROR) + +SD_PREFIX_VAE = "vae" +SD_PREFIX_UNET = "unet" +SD_PREFIX_REFINER_UNET = "refiner_unet" +SD_PREFIX_TEXT_ENCODER = "te" + +SD_PREFIX_TEXT_ENCODER1 = "te0" +SD_PREFIX_TEXT_ENCODER2 = "te1" + +# prefixed diffusers keys +DO_NOT_TRAIN_WEIGHTS = [ + "unet_time_embedding.linear_1.bias", + "unet_time_embedding.linear_1.weight", + "unet_time_embedding.linear_2.bias", + "unet_time_embedding.linear_2.weight", + "refiner_unet_time_embedding.linear_1.bias", + "refiner_unet_time_embedding.linear_1.weight", + "refiner_unet_time_embedding.linear_2.bias", + "refiner_unet_time_embedding.linear_2.weight", +] + +DeviceStatePreset = Literal['cache_latents', 'generate'] + + +class BlankNetwork: + + def __init__(self): + self.multiplier = 1.0 + self.is_active = True + self.is_merged_in = False + self.can_merge_in = False + + def __enter__(self): + self.is_active = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.is_active = False + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 +# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 + + + +class StableDiffusion: + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='fp16', + custom_pipeline=None, + noise_scheduler=None, + quantize_device=None, + ): + self.custom_pipeline = custom_pipeline + self.device = device + self.dtype = dtype + self.torch_dtype = get_torch_dtype(dtype) + self.device_torch = torch.device(self.device) + + self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device( + model_config.vae_device) + self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype) + + self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device( + model_config.te_device) + self.te_torch_dtype = get_torch_dtype(model_config.te_dtype) + + self.model_config = model_config + self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" + + self.device_state = None + + self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] + self.vae: Union[None, 'AutoencoderKL'] + self.unet: Union[None, 'UNet2DConditionModel'] + self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] + self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] + self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler + + self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None + self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None + + # sdxl stuff + self.logit_scale = None + self.ckppt_info = None + self.is_loaded = False + + # to hold network if there is one + self.network = None + self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None + self.decorator: Union[Decorator, None] = None + self.is_xl = model_config.is_xl + self.is_v2 = model_config.is_v2 + self.is_ssd = model_config.is_ssd + self.is_v3 = model_config.is_v3 + self.is_vega = 
model_config.is_vega + self.is_pixart = model_config.is_pixart + self.is_auraflow = model_config.is_auraflow + self.is_flux = model_config.is_flux + + self.use_text_encoder_1 = model_config.use_text_encoder_1 + self.use_text_encoder_2 = model_config.use_text_encoder_2 + + self.config_file = None + + self.is_flow_matching = False + if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): + self.is_flow_matching = True + + self.quantize_device = quantize_device if quantize_device is not None else self.device + self.low_vram = self.model_config.low_vram + + # merge in and preview active with -1 weight + self.invert_assistant_lora = False + + def load_model(self): + if self.is_loaded: + return + dtype = get_torch_dtype(self.dtype) + + # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why + # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch) + # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch) + # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch) + + model_path = self.model_config.name_or_path + if 'civitai.com' in self.model_config.name_or_path: + # load is a civit ai model, use the loader. + from toolkit.civitai import get_model_path_from_url + model_path = get_model_path_from_url(self.model_config.name_or_path) + + load_args = {} + if self.noise_scheduler: + load_args['scheduler'] = self.noise_scheduler + + if self.model_config.vae_path is not None: + load_args['vae'] = load_vae(self.model_config.vae_path, dtype) + if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusionXLPipeline + # pipln = StableDiffusionKDiffusionXLPipeline + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + # try to load with default diffusers + pipe = pipln.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + # variant="fp16", + use_safetensors=True, + **load_args + ) + else: + pipe = pipln.from_single_file( + model_path, + device=self.device_torch, + torch_dtype=self.torch_dtype, + ) + + if 'vae' in load_args and load_args['vae'] is not None: + pipe.vae = load_args['vae'] + flush() + + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + for text_encoder in text_encoders: + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + text_encoder = text_encoders + + pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + + if self.model_config.experimental_xl: + print("Experimental XL mode enabled") + print("Loading and injecting alt weights") + # load the mismatched weight and force it in + raw_state_dict = load_file(model_path) + replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone() + del raw_state_dict + # get state dict for for 2nd text encoder + te1_state_dict = text_encoders[1].state_dict() + # replace weight with mismatched weight + te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype) + flush() + print("Injecting alt weights") + elif self.model_config.is_v3: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusion3Pipeline + + 
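+ # SD3 path: the transformer and the T5 text encoder are loaded separately below so they can be
+ # quantized to qfloat8 via optimum.quanto (when model_config.quantize is set) before the pipeline is assembled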
print("Loading SD3 model") + # assume it is the large model + base_model_path = "stabilityai/stable-diffusion-3.5-large" + print("Loading transformer") + subfolder = 'transformer' + transformer_path = model_path + # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + else: + # is remote use whatever path we were given + base_model_path = model_path + + transformer = SD3Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + if not self.low_vram: + # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu + transformer.to(torch.device(self.quantize_device), dtype=dtype) + flush() + + if self.model_config.lora_path is not None: + raise ValueError("LoRA is not supported for SD3 models currently") + + if self.model_config.quantize: + quantization_type = qfloat8 + print("Quantizing transformer") + quantize(transformer, weights=quantization_type) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + print("Loading vae") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + print("Loading t5") + tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype) + text_encoder_3 = T5EncoderModel.from_pretrained( + base_model_path, + subfolder="text_encoder_3", + torch_dtype=dtype + ) + + text_encoder_3.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize: + print("Quantizing T5") + quantize(text_encoder_3, weights=qfloat8) + freeze(text_encoder_3) + flush() + + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + try: + # try to load with default diffusers + pipe = pipln.from_pretrained( + base_model_path, + dtype=dtype, + device=self.device_torch, + tokenizer_3=tokenizer_3, + text_encoder_3=text_encoder_3, + transformer=transformer, + # variant="fp16", + use_safetensors=True, + repo_type="model", + ignore_patterns=["*.md", "*..gitattributes"], + **load_args + ) + except Exception as e: + print(f"Error loading from pretrained: {e}") + raise e + + else: + pipe = pipln.from_single_file( + model_path, + transformer=transformer, + device=self.device_torch, + torch_dtype=self.torch_dtype, + tokenizer_3=tokenizer_3, + text_encoder_3=text_encoder_3, + **load_args + ) + + flush() + + text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3] + # replace the to function with a no-op since it throws an error instead of a warning + # text_encoders[2].to = lambda *args, **kwargs: None + for text_encoder in text_encoders: + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + text_encoder = text_encoders + + + elif self.model_config.is_pixart: + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True 
+ te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS" + if self.model_config.is_pixart_sigma: + main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" + + main_model_path = model_path + + # load the TE in 8bit mode + text_encoder = T5EncoderModel.from_pretrained( + main_model_path, + subfolder="text_encoder", + torch_dtype=self.torch_dtype, + **te_kwargs + ) + + # load the transformer + subfolder = "transformer" + # check if it is just the unet + if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): + subfolder = None + + if te_is_quantized: + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + + if self.model_config.is_pixart_sigma: + # load the transformer only from the save + transformer = Transformer2DModel.from_pretrained( + model_path if self.model_config.unet_path is None else self.model_config.unet_path, + torch_dtype=self.torch_dtype, + subfolder='transformer' + ) + pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ) + + else: + + # load the transformer only from the save + transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, + subfolder=subfolder) + pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ).to(self.device_torch) + + if self.model_config.unet_sample_size is not None: + pipe.transformer.config.sample_size = self.model_config.unet_sample_size + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + + flush() + # text_encoder = pipe.text_encoder + # text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + tokenizer = pipe.tokenizer + + pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + if self.noise_scheduler is None: + self.noise_scheduler = pipe.scheduler + + + elif self.model_config.is_auraflow: + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + main_model_path = model_path + + # load the TE in 8bit mode + text_encoder = UMT5EncoderModel.from_pretrained( + main_model_path, + subfolder="text_encoder", + torch_dtype=self.torch_dtype, + **te_kwargs + ) + + # load the transformer + subfolder = "transformer" + # check if it is just the unet + if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): + subfolder = None + + if te_is_quantized: + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + + # load the transformer only from the save + transformer = 
AuraFlowTransformer2DModel.from_pretrained( + model_path if self.model_config.unet_path is None else self.model_config.unet_path, + torch_dtype=self.torch_dtype, + subfolder='transformer' + ) + pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ) + + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + + # patch auraflow so it can handle other aspect ratios + # patch_auraflow_pos_embed(pipe.transformer.pos_embed) + + flush() + # text_encoder = pipe.text_encoder + # text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + tokenizer = pipe.tokenizer + + elif self.model_config.is_flux: + print("Loading Flux model") + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original + print("Loading transformer") + subfolder = 'transformer' + transformer_path = model_path + local_files_only = False + # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = FluxTransformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + # low_cpu_mem_usage=False, + # device_map=None + ) + if not self.low_vram: + # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu + transformer.to(torch.device(self.quantize_device), dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None: + raise ValueError("Cannot load both assistant lora and inference lora at the same time") + + if self.model_config.lora_path: + raise ValueError("Cannot load both assistant lora and lora at the same time") + + if not self.is_flux: + raise ValueError("Assistant/ inference lora is only supported for flux models currently") + + load_lora_path = self.model_config.inference_lora_path + if load_lora_path is None: + load_lora_path = self.model_config.assistant_lora_path + + if os.path.isdir(load_lora_path): + load_lora_path = os.path.join( + load_lora_path, "pytorch_lora_weights.safetensors" + ) + elif not os.path.exists(load_lora_path): + print(f"Grabbing lora from the hub: {load_lora_path}") + new_lora_path = hf_hub_download( + load_lora_path, + filename="pytorch_lora_weights.safetensors" + ) + # replace the path + load_lora_path = new_lora_path + + if self.model_config.inference_lora_path is not None: + self.model_config.inference_lora_path = new_lora_path + if self.model_config.assistant_lora_path is not None: + self.model_config.assistant_lora_path = new_lora_path + + if self.model_config.assistant_lora_path is not None: + # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on + # quantized weights so it had to process unmerged (slow). 
Since schnell samples in just 4 steps + # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half + # so we will merge in now and sample with -1 weight later + self.invert_assistant_lora = True + # trigger it to get merged in + self.model_config.lora_path = self.model_config.assistant_lora_path + + if self.model_config.lora_path is not None: + print("Fusing in LoRA") + # need the pipe for peft + pipe: FluxPipeline = FluxPipeline( + scheduler=None, + text_encoder=None, + tokenizer=None, + text_encoder_2=None, + tokenizer_2=None, + vae=None, + transformer=transformer, + ) + if self.low_vram: + # we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts + # we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu + # we are going to separate it into the two transformer blocks one at a time + + lora_state_dict = load_file(self.model_config.lora_path) + single_transformer_lora = {} + single_block_key = "transformer.single_transformer_blocks." + double_transformer_lora = {} + double_block_key = "transformer.transformer_blocks." + for key, value in lora_state_dict.items(): + if single_block_key in key: + single_transformer_lora[key] = value + elif double_block_key in key: + double_transformer_lora[key] = value + else: + raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode") + + # double blocks + transformer.transformer_blocks = transformer.transformer_blocks.to( + torch.device(self.quantize_device), dtype=dtype + ) + pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double") + pipe.fuse_lora() + pipe.unload_lora_weights() + transformer.transformer_blocks = transformer.transformer_blocks.to( + 'cpu', dtype=dtype + ) + + # single blocks + transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( + torch.device(self.quantize_device), dtype=dtype + ) + pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single") + pipe.fuse_lora() + pipe.unload_lora_weights() + transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( + 'cpu', dtype=dtype + ) + + # cleanup + del single_transformer_lora + del double_transformer_lora + del lora_state_dict + flush() + + else: + # need the pipe to do this unfortunately for now + # we have to fuse in the weights before quantizing + pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") + pipe.fuse_lora() + # unfortunately, not an easier way with peft + pipe.unload_lora_weights() + flush() + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = qfloat8 + print("Quantizing transformer") + quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + print("Loading vae") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + print("Loading t5") + tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) + text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", + torch_dtype=dtype) + + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize: + 
print("Quantizing T5") + quantize(text_encoder_2, weights=qfloat8) + freeze(text_encoder_2) + flush() + + print("Loading clip") + text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(self.device_torch, dtype=dtype) + + print("making pipe") + pipe: FluxPipeline = FluxPipeline( + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + print("preparing") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + else: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusionPipeline + + if self.model_config.text_encoder_bits < 16: + # this is only supported for T5 models for now + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + text_encoder = T5EncoderModel.from_pretrained( + model_path, + subfolder="text_encoder", + torch_dtype=self.te_torch_dtype, + **te_kwargs + ) + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + + load_args['text_encoder'] = text_encoder + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + # try to load with default diffusers + pipe = pipln.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + load_safety_checker=False, + requires_safety_checker=False, + safety_checker=None, + # variant="fp16", + trust_remote_code=True, + **load_args + ) + else: + pipe = pipln.from_single_file( + model_path, + dtype=dtype, + device=self.device_torch, + load_safety_checker=False, + requires_safety_checker=False, + torch_dtype=self.torch_dtype, + safety_checker=None, + trust_remote_code=True, + **load_args + ) + flush() + + pipe.register_to_config(requires_safety_checker=False) + text_encoder = pipe.text_encoder + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + tokenizer = pipe.tokenizer + + # scheduler doesn't get set sometimes, so we set it here + pipe.scheduler = self.noise_scheduler + + # add hacks to unet to help training + # pipe.unet = prepare_unet_for_training(pipe.unet) + + if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: + # pixart and sd3 dont use a unet + self.unet = pipe.transformer + else: + self.unet: 'UNet2DConditionModel' = pipe.unet + self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + self.vae.eval() + self.vae.requires_grad_(False) + VAE_SCALE_FACTOR = 2 ** 
(len(self.vae.config['block_out_channels']) - 1) + self.vae_scale_factor = VAE_SCALE_FACTOR + self.unet.to(self.device_torch, dtype=dtype) + self.unet.requires_grad_(False) + self.unet.eval() + + # load any loras we have + if self.model_config.lora_path is not None and not self.is_flux: + pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") + pipe.fuse_lora() + # unfortunately, not an easier way with peft + pipe.unload_lora_weights() + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.pipeline = pipe + self.load_refiner() + self.is_loaded = True + + if self.model_config.assistant_lora_path is not None: + print("Loading assistant lora") + self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( + self.model_config.assistant_lora_path, self) + + if self.invert_assistant_lora: + # invert and disable during training + self.assistant_lora.multiplier = -1.0 + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print("Loading inference lora") + self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( + self.model_config.inference_lora_path, self) + # disable during training + self.assistant_lora.is_active = False + + if self.is_pixart and self.vae_scale_factor == 16: + # TODO make our own pipeline? + # we generate an image 2x larger, so we need to copy the sizes from larger ones down + # ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN + for key in ASPECT_RATIO_256_BIN.keys(): + ASPECT_RATIO_256_BIN[key] = [ASPECT_RATIO_256_BIN[key][0] * 2, ASPECT_RATIO_256_BIN[key][1] * 2] + for key in ASPECT_RATIO_512_BIN.keys(): + ASPECT_RATIO_512_BIN[key] = [ASPECT_RATIO_512_BIN[key][0] * 2, ASPECT_RATIO_512_BIN[key][1] * 2] + for key in ASPECT_RATIO_1024_BIN.keys(): + ASPECT_RATIO_1024_BIN[key] = [ASPECT_RATIO_1024_BIN[key][0] * 2, ASPECT_RATIO_1024_BIN[key][1] * 2] + for key in ASPECT_RATIO_2048_BIN.keys(): + ASPECT_RATIO_2048_BIN[key] = [ASPECT_RATIO_2048_BIN[key][0] * 2, ASPECT_RATIO_2048_BIN[key][1] * 2] + + def te_train(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.train() + else: + self.text_encoder.train() + + def te_eval(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.eval() + else: + self.text_encoder.eval() + + def load_refiner(self): + # for now, we are just going to rely on the TE from the base model + # which is TE2 for SDXL and TE for SD (no refiner currently) + # and completely ignore a TE that may or may not be packaged with the refiner + if self.model_config.refiner_name_or_path is not None: + refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') + # load the refiner model + dtype = get_torch_dtype(self.dtype) + model_path = self.model_config.refiner_name_or_path + if not os.path.exists(model_path) or os.path.isdir(model_path): + # TODO only load unet?? 
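+ # A rough sketch (hypothetical helper, never called) of what the TODO above could
+ # look like: loading only the refiner's UNet instead of the full pipeline,
+ # assuming a standard diffusers folder layout with a "unet" subfolder.
+ def _sketch_load_refiner_unet_only(path, dtype, device):
+     from diffusers import UNet2DConditionModel
+     return UNet2DConditionModel.from_pretrained(
+         path, subfolder="unet", torch_dtype=dtype
+     ).to(device)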
+ refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + # variant="fp16", + use_safetensors=True, + ).to(self.device_torch) + else: + refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( + model_path, + dtype=dtype, + device=self.device_torch, + torch_dtype=self.torch_dtype, + original_config_file=refiner_config_path, + ).to(self.device_torch) + + self.refiner_unet = refiner.unet + del refiner + flush() + + @torch.no_grad() + def generate_images( + self, + image_configs: List[GenerateImageConfig], + sampler=None, + pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, + ): + merge_multiplier = 1.0 + flush() + # if using assistant, unfuse it + if self.model_config.assistant_lora_path is not None: + print("Unloading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + else: + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print("Loading inference lora") + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + + if self.network is not None: + self.network.eval() + network = self.network + # check if we have the same network weight for all samples. If we do, we can merge in th + # the network to drastically speed up inference + unique_network_weights = set([x.network_multiplier for x in image_configs]) + if len(unique_network_weights) == 1 and self.network.can_merge_in: + can_merge_in = True + merge_multiplier = unique_network_weights.pop() + network.merge_in(merge_weight=merge_multiplier) + else: + network = BlankNetwork() + + self.save_device_state() + self.set_device_state_preset('generate') + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + if pipeline is None: + noise_scheduler = self.noise_scheduler + if sampler is not None: + if sampler.startswith("sample_"): # sample_dpmpp_2m + # using ksampler + noise_scheduler = get_sampler( + 'lms', { + "prediction_type": self.prediction_type, + }) + else: + noise_scheduler = get_sampler( + sampler, + { + "prediction_type": self.prediction_type, + }, + 'sd' if not self.is_pixart else 'pixart' + ) + + try: + noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype) + except: + pass + + if sampler.startswith("sample_") and self.is_xl: + # using kdiffusion + Pipe = StableDiffusionKDiffusionXLPipeline + elif self.is_xl: + Pipe = StableDiffusionXLPipeline + elif self.is_v3: + Pipe = StableDiffusion3Pipeline + else: + Pipe = StableDiffusionPipeline + + extra_args = {} + if self.adapter is not None: + if isinstance(self.adapter, T2IAdapter): + if self.is_xl: + Pipe = StableDiffusionXLAdapterPipeline + else: + Pipe = StableDiffusionAdapterPipeline + extra_args['adapter'] = self.adapter + elif isinstance(self.adapter, ControlNetModel): + if self.is_xl: + Pipe = StableDiffusionXLControlNetPipeline + else: + Pipe = StableDiffusionControlNetPipeline + extra_args['controlnet'] = self.adapter + elif isinstance(self.adapter, ReferenceAdapter): + # pass the noise scheduler to the adapter + self.adapter.noise_scheduler = noise_scheduler + else: + if self.is_xl: + extra_args['add_watermarker'] = False + + # TODO add clip skip + if self.is_xl: + pipeline = Pipe( + 
vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ).to(self.device_torch) + pipeline.watermark = None + elif self.is_flux: + if self.model_config.use_flux_cfg: + pipeline = FluxWithCFGPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ) + + else: + pipeline = FluxPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ) + pipeline.watermark = None + elif self.is_v3: + pipeline = Pipe( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + text_encoder_3=self.text_encoder[2], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + tokenizer_3=self.tokenizer[2], + scheduler=noise_scheduler, + **extra_args + ) + elif self.is_pixart: + pipeline = PixArtSigmaPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) + + elif self.is_auraflow: + pipeline = AuraFlowPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) + + else: + pipeline = Pipe( + vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + **extra_args + ) + flush() + # disable progress bar + pipeline.set_progress_bar_config(disable=True) + + if sampler.startswith("sample_"): + pipeline.set_scheduler(sampler) + + refiner_pipeline = None + if self.refiner_unet: + # build refiner pipeline + refiner_pipeline = StableDiffusionXLImg2ImgPipeline( + vae=pipeline.vae, + unet=self.refiner_unet, + text_encoder=None, + text_encoder_2=pipeline.text_encoder_2, + tokenizer=None, + tokenizer_2=pipeline.tokenizer_2, + scheduler=pipeline.scheduler, + add_watermarker=False, + requires_aesthetics_score=True, + ).to(self.device_torch) + # refiner_pipeline.register_to_config(requires_aesthetics_score=False) + refiner_pipeline.watermark = None + refiner_pipeline.set_progress_bar_config(disable=True) + flush() + + start_multiplier = 1.0 + if self.network is not None: + start_multiplier = self.network.multiplier + + # pipeline.to(self.device_torch) + + with network: + with torch.no_grad(): + if self.network is not None: + assert self.network.is_active + + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + gen_config = image_configs[i] + + extra = {} + validation_image = None + if self.adapter is not None and gen_config.adapter_image_path is not None: + validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? 
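+ # For reference, the adapter branches below prepare the control image in two
+ # different value ranges. A minimal sketch (hypothetical helper, never called)
+ # of both conversions used further down:
+ def _sketch_control_image_tensors(pil_image):
+     t01 = transforms.ToTensor()(pil_image)  # PIL -> (C, H, W) float in [0, 1], used by the IP-Adapter/CustomAdapter branches
+     t11 = t01 * 2.0 - 1.0                    # rescaled to [-1, 1], used by the ReferenceAdapter branch
+     return t01, t11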
+ validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + self.adapter.set_reference_images(validation_image) + + if self.network is not None: + self.network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + generator = torch.manual_seed(gen_config.seed) + + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) + self.adapter(conditional_clip_embeds) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, + ) + + if self.decorator is not None: + # apply the decorator to the embeddings + conditional_embeds.text_embeds = self.decorator(conditional_embeds.text_embeds) + unconditional_embeds.text_embeds = self.decorator(unconditional_embeds.text_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = 
self.adapter.get_clip_image_embeds_from_tensors(validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, + True) + conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False) + unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, + CustomAdapter) and validation_image is not None: + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_generating_samples=True, + ) + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + is_generating_samples=True, + ) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( + gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, + dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values(extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # if we have a refiner loaded, set the denoising end at the refiner start + extra['denoising_end'] = gen_config.refiner_start_at + extra['output_type'] = 'latent' + if not self.is_xl: + raise ValueError("Refiner is only supported for XL models") + + conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype) + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype) + + if self.is_xl: + # fix guidance rescale for sdxl + # was trained on 0.7 (I believe) + + grs = gen_config.guidance_rescale + # if grs is None or grs < 0.00001: + # grs = 0.7 + # grs = 0.0 + + if sampler.startswith("sample_"): + extra['use_karras_sigmas'] = True + extra = { + **extra, + **gen_config.extra_kwargs, + } + + img = pipeline( + # prompt=gen_config.prompt, + # prompt_2=gen_config.prompt_2, + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + # negative_prompt=gen_config.negative_prompt, + # negative_prompt_2=gen_config.negative_prompt_2, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + guidance_rescale=grs, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + elif self.is_v3: + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + elif self.is_flux: + if self.model_config.use_flux_cfg: + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + 
pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + else: + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + # negative_prompt_embeds=unconditional_embeds.text_embeds, + # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + elif self.is_pixart: + # needs attention masks for some reason + img = pipeline( + prompt=None, + prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt=None, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + elif self.is_auraflow: + pipeline: AuraFlowPipeline = pipeline + + img = pipeline( + prompt=None, + prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt=None, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + else: + img = pipeline( + # prompt=gen_config.prompt, + prompt_embeds=conditional_embeds.text_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # slide off just the last 1280 on the last dim as refiner does not use first text encoder + # todo, should we just use the Text encoder for the refiner? 
Fine tuned versions will differ + refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:] + refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:] + # run through refiner + img = refiner_pipeline( + # prompt=gen_config.prompt, + # prompt_2=gen_config.prompt_2, + + # slice these as it does not use both text encoders + # height=gen_config.height, + # width=gen_config.width, + prompt_embeds=refiner_text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=refiner_unconditional_text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + guidance_rescale=grs, + denoising_start=gen_config.refiner_start_at, + denoising_end=gen_config.num_inference_steps, + image=img.unsqueeze(0), + generator=generator, + ).images[0] + + gen_config.save_image(img, i) + gen_config.log_image(img, i) + flush() + + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + + # clear pipeline and cache to reduce vram usage + del pipeline + if refiner_pipeline is not None: + del refiner_pipeline + torch.cuda.empty_cache() + + # restore training state + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + self.restore_device_state() + if self.network is not None: + self.network.train() + self.network.multiplier = start_multiplier + + self.unet.to(self.device_torch, dtype=self.torch_dtype) + if network.is_merged_in: + network.merge_out(merge_multiplier) + # self.tokenizer.to(original_device_dict['tokenizer']) + + # refuse loras + if self.model_config.assistant_lora_path is not None: + print("Loading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + else: + self.assistant_lora.is_active = True + + if self.model_config.inference_lora_path is not None: + print("Unloading inference lora") + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + + flush() + + def get_latent_noise( + self, + height=None, + width=None, + pixel_height=None, + pixel_width=None, + batch_size=1, + noise_offset=0.0, + ): + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + if height is None and pixel_height is None: + raise ValueError("height or pixel_height must be specified") + if width is None and pixel_width is None: + raise ValueError("width or pixel_width must be specified") + if height is None: + height = pixel_height // VAE_SCALE_FACTOR + if width is None: + width = pixel_width // VAE_SCALE_FACTOR + + num_channels = self.unet.config['in_channels'] + if self.is_flux: + # has 64 channels in for some reason + num_channels = 16 + noise = torch.randn( + ( + batch_size, + num_channels, + height, + width, + ), + device=self.unet.device, + ) + noise = apply_noise_offset(noise, noise_offset) + return noise + + def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False): + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + if self.is_xl: + bs, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + + dtype = latents.dtype + # just do it without any cropping nonsense + target_size = (height, width) + original_size = (height, width) + 
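+ # Worked example, assuming the standard SD/SDXL VAE config with 4 entries in
+ # block_out_channels: VAE_SCALE_FACTOR = 2 ** (4 - 1) = 8, so a 128x128 latent
+ # corresponds to a 1024x1024 image, and the SDXL micro-conditioning built below
+ # packs original_size + crop_coords_top_left + target_size into 6 values per
+ # sample, e.g. [1024, 1024, 0, 0, 1024, 1024].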
crops_coords_top_left = (0, 0) + if requires_aesthetic_score: + # refiner + # https://huggingface.co/papers/2307.01952 + aesthetic_score = 6.0 # simulate one + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(latents.device, dtype=dtype) + + batch_time_ids = torch.cat( + [add_time_ids for _ in range(bs)] + ) + return batch_time_ids + else: + return None + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor + ) -> torch.FloatTensor: + original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0) + noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + noisy_latents_chunks = [] + + for idx in range(original_samples.shape[0]): + noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], + timesteps_chunks[idx]) + noisy_latents_chunks.append(noisy_latents) + + noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + return noisy_latents + + def predict_noise( + self, + latents: torch.Tensor, + text_embeddings: Union[PromptEmbeds, None] = None, + timestep: Union[int, torch.Tensor] = 1, + guidance_scale=7.5, + guidance_rescale=0, + add_time_ids=None, + conditional_embeddings: Union[PromptEmbeds, None] = None, + unconditional_embeddings: Union[PromptEmbeds, None] = None, + is_input_scaled=False, + detach_unconditional=False, + rescale_cfg=None, + return_conditional_pred=False, + guidance_embedding_scale=1.0, + bypass_guidance_embedding=False, + **kwargs, + ): + conditional_pred = None + # get the embeddings + if text_embeddings is None and conditional_embeddings is None: + raise ValueError("Either text_embeddings or conditional_embeddings must be specified") + if text_embeddings is None and unconditional_embeddings is not None: + text_embeddings = concat_prompt_embeds([ + unconditional_embeddings, # negative embedding + conditional_embeddings, # positive embedding + ]) + elif text_embeddings is None and conditional_embeddings is not None: + # not doing cfg + text_embeddings = conditional_embeddings + + # CFG is comparing neg and positive, if we have concatenated embeddings + # then we are doing it, otherwise we are not and takes half the time. 
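+ # A minimal sketch (hypothetical helper, never called) of the classifier-free
+ # guidance math this method applies further down once the batch has been doubled
+ # up, with the unconditional predictions in the first half of the batch and the
+ # conditional predictions in the second half:
+ def _sketch_cfg_combine(noise_pred, guidance_scale):
+     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
+     # guided prediction = uncond + scale * (cond - uncond)
+     return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)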
+ do_classifier_free_guidance = True + + # check if batch size of embeddings matches batch size of latents + if latents.shape[0] == text_embeddings.text_embeds.shape[0]: + do_classifier_free_guidance = False + elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: + raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings") + latents = latents.to(self.device_torch) + text_embeddings = text_embeddings.to(self.device_torch) + timestep = timestep.to(self.device_torch) + + # if timestep is zero dim, unsqueeze it + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + + # if we only have 1 timestep, we can just use the same timestep for all + if timestep.shape[0] == 1 and latents.shape[0] > 1: + # check if it is rank 1 or 2 + if len(timestep.shape) == 1: + timestep = timestep.repeat(latents.shape[0]) + else: + timestep = timestep.repeat(latents.shape[0], 0) + + # handle t2i adapters + if 'down_intrablock_additional_residuals' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) + + # handle controlnet + if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_block_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) + for idx, item in enumerate(kwargs['mid_block_additional_residual']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0) + + def scale_model_input(model_input, timestep_tensor): + if is_input_scaled: + return model_input + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + # unsqueeze if timestep is zero dim + for idx in range(model_input.shape[0]): + # if scheduler has step_index + if hasattr(self.noise_scheduler, '_step_index'): + self.noise_scheduler._step_index = None + out_chunks.append( + self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx]) + ) + return torch.cat(out_chunks, dim=0) + + if self.is_xl: + with torch.no_grad(): + # 16, 6 for bs of 4 + if add_time_ids is None: + add_time_ids = self.get_time_ids_from_latents(latents) + + if do_classifier_free_guidance: + # todo check this with larget batches + add_time_ids = torch.cat([add_time_ids] * 2) + + if do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input(latent_model_input, timestep) + + added_cond_kwargs = { + # todo can we zero here the second text encoder? or match a blank string? + "text_embeds": text_embeddings.pooled_embeds, + "time_ids": add_time_ids, + } + + if self.model_config.refiner_name_or_path is not None: + # we have the refiner on the second half of everything. 
Do Both + if do_classifier_free_guidance: + raise ValueError("Refiner is not supported with classifier free guidance") + + if self.unet.training: + input_chunks = torch.chunk(latent_model_input, 2, dim=0) + timestep_chunks = torch.chunk(timestep, 2, dim=0) + added_cond_kwargs_chunked = { + "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0), + "time_ids": torch.chunk(add_time_ids, 2, dim=0), + } + text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0) + + # predict the noise residual + base_pred = self.unet( + input_chunks[0], + timestep_chunks[0], + encoder_hidden_states=text_embeds_chunks[0], + added_cond_kwargs={ + "text_embeds": added_cond_kwargs_chunked['text_embeds'][0], + "time_ids": added_cond_kwargs_chunked['time_ids'][0], + }, + **kwargs, + ).sample + + refiner_pred = self.refiner_unet( + input_chunks[1], + timestep_chunks[1], + encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], + # just use the first second text encoder + added_cond_kwargs={ + "text_embeds": added_cond_kwargs_chunked['text_embeds'][1], + # "time_ids": added_cond_kwargs_chunked['time_ids'][1], + "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True), + }, + **kwargs, + ).sample + + noise_pred = torch.cat([base_pred, refiner_pred], dim=0) + else: + noise_pred = self.refiner_unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:], + # just use the first second text encoder + added_cond_kwargs={ + "text_embeds": text_embeddings.pooled_embeds, + "time_ids": self.get_time_ids_from_latents(latent_model_input, + requires_aesthetic_score=True), + }, + **kwargs, + ).sample + + else: + + # predict the noise residual + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + timestep, + encoder_hidden_states=text_embeddings.text_embeds, + added_cond_kwargs=added_cond_kwargs, + **kwargs, + ).sample + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + conditional_pred = noise_pred_text + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + else: + with torch.no_grad(): + if do_classifier_free_guidance: + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2, dim=0) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input(latent_model_input, timestep) + + # check if we need to concat timesteps + if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: + ts_bs = timestep.shape[0] + if ts_bs != latent_model_input.shape[0]: + if ts_bs == 1: + timestep = torch.cat([timestep] * latent_model_input.shape[0]) + elif ts_bs * 2 == latent_model_input.shape[0]: + timestep = torch.cat([timestep] * 2, dim=0) + else: + raise ValueError( + f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + + # predict the noise residual + if self.is_pixart: + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + batch_size, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + + if self.pipeline.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.pipeline.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.pipeline.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.pipeline.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}") + orig_height, orig_width = height, width + height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, + ratios=aspect_ratio_bin) + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.unet.config.sample_size == 128 or ( + self.vae_scale_factor == 16 and self.unet.config.sample_size == 64): + resolution = torch.tensor([height, width]).repeat(batch_size, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) + resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) + aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + encoder_hidden_states=text_embeddings.text_embeds, + encoder_attention_mask=text_embeddings.attention_mask, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + **kwargs + )[0] + + # learned sigma + if self.unet.config.out_channels // 2 == self.unet.config.in_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + else: + if self.unet.device != self.device_torch: + self.unet.to(self.device_torch) + if self.unet.dtype != self.torch_dtype: + self.unet = self.unet.to(dtype=self.torch_dtype) + if self.is_flux: + with torch.no_grad(): + + bs, c, h, w = latent_model_input.shape + latent_model_input_packed = rearrange( + latent_model_input, + "b c (h ph) (w pw) -> b (h w) (c ph pw)", + ph=2, + pw=2 + ) + + img_ids = torch.zeros(h // 
2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch) + + txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + # # handle guidance + if self.unet.config.guidance_embeds: + if isinstance(guidance_embedding_scale, list): + guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch) + else: + guidance = torch.tensor([guidance_embedding_scale], device=self.device_torch) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if bypass_guidance_embedding: + bypass_flux_guidance(self.unet) + + cast_dtype = self.unet.dtype + # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype): + noise_pred = self.unet( + hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64] + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + # todo make sure this doesnt change + timestep=timestep / 1000, # timestep is 1000 scale + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), + # [1, 512, 4096] + pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768] + txt_ids=txt_ids, # [1, 512, 3] + img_ids=img_ids, # [1, 4096, 3] + guidance=guidance, + return_dict=False, + **kwargs, + )[0] + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + noise_pred = rearrange( + noise_pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=latent_model_input.shape[2] // 2, + w=latent_model_input.shape[3] // 2, + ph=2, + pw=2, + c=latent_model_input.shape[1], + ) + + if bypass_guidance_embedding: + restore_flux_guidance(self.unet) + elif self.is_v3: + noise_pred = self.unet( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + elif self.is_auraflow: + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0]) + t = t.to(self.device_torch, self.torch_dtype) + + noise_pred = self.unet( + latent_model_input, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + timestep=t, + return_dict=False, + )[0] + else: + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) + conditional_pred = noise_pred_text + if detach_unconditional: + noise_pred_uncond = noise_pred_uncond.detach() + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if rescale_cfg is not None and rescale_cfg != guidance_scale: + with torch.no_grad(): + # do cfg at the 
target rescale so we can match it + target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( + noise_pred_text - noise_pred_uncond + ) + target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach() + target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach() + + pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach() + pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() + + # match the mean and std + noise_pred = (noise_pred - pred_mean) / pred_std + noise_pred = (noise_pred * target_std) + target_mean + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if return_conditional_pred: + return noise_pred, conditional_pred + return noise_pred + + def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None): + if noise_scheduler is None: + noise_scheduler = self.noise_scheduler + # // sometimes they are on the wrong device, no idea why + if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler): + try: + noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch) + noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch) + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch) + except Exception as e: + pass + + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) + timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + if len(timestep_chunks) == 1 and len(mi_chunks) > 1: + # expand timestep to match + timestep_chunks = timestep_chunks * len(mi_chunks) + + for idx in range(model_input.shape[0]): + # Reset it so it is unique for the + if hasattr(noise_scheduler, '_step_index'): + noise_scheduler._step_index = None + if hasattr(noise_scheduler, 'is_scale_input_called'): + noise_scheduler.is_scale_input_called = True + out_chunks.append( + noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[ + 0] + ) + return torch.cat(out_chunks, dim=0) + + # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + total_timesteps: int = 1000, + start_timesteps=0, + guidance_scale=1, + add_time_ids=None, + bleed_ratio: float = 0.5, + bleed_latents: torch.FloatTensor = None, + is_input_scaled=False, + return_first_prediction=False, + **kwargs, + ): + timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] + + first_prediction = None + + for timestep in tqdm(timesteps_to_run, leave=False): + timestep = timestep.unsqueeze_(0) + noise_pred, conditional_pred = self.predict_noise( + latents, + text_embeddings, + timestep, + guidance_scale=guidance_scale, + add_time_ids=add_time_ids, + is_input_scaled=is_input_scaled, + return_conditional_pred=True, + **kwargs, + ) + # some schedulers need to run separately, so do that. 
(euler for example) + + if return_first_prediction and first_prediction is None: + first_prediction = conditional_pred + + latents = self.step_scheduler(noise_pred, latents, timestep) + + # if not last step, and bleeding, bleed in some latents + if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: + latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio) + + # only skip first scaling + is_input_scaled = False + + # return latents_steps + if return_first_prediction: + return latents, first_prediction + return latents + + def encode_prompt( + self, + prompt, + prompt2=None, + num_images_per_prompt=1, + force_all=False, + long_prompts=False, + max_length=None, + dropout_prob=0.0, + ) -> PromptEmbeds: + # sd1.5 embeddings are (bs, 77, 768) + prompt = prompt + # if it is not a list, make it one + if not isinstance(prompt, list): + prompt = [prompt] + + if prompt2 is not None and not isinstance(prompt2, list): + prompt2 = [prompt2] + if self.is_xl: + # todo make this a config + # 50% chance to use an encoder anyway even if it is disabled + # allows the other TE to compensate for the disabled one + # use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5 + # use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5 + use_encoder_1 = True + use_encoder_2 = True + + return PromptEmbeds( + train_tools.encode_prompts_xl( + self.tokenizer, + self.text_encoder, + prompt, + prompt2, + num_images_per_prompt=num_images_per_prompt, + use_text_encoder_1=use_encoder_1, + use_text_encoder_2=use_encoder_2, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob, + ) + ) + if self.is_v3: + return PromptEmbeds( + train_tools.encode_prompts_sd3( + self.tokenizer, + self.text_encoder, + prompt, + num_images_per_prompt=num_images_per_prompt, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob, + pipeline=self.pipeline, + ) + ) + elif self.is_pixart: + embeds, attention_mask = train_tools.encode_prompts_pixart( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=300 if self.model_config.is_pixart_sigma else 120, + dropout_prob=dropout_prob + ) + return PromptEmbeds( + embeds, + attention_mask=attention_mask, + ) + elif self.is_auraflow: + embeds, attention_mask = train_tools.encode_prompts_auraflow( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=256, + dropout_prob=dropout_prob + ) + return PromptEmbeds( + embeds, + attention_mask=attention_mask, # not used + ) + elif self.is_flux: + prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( + self.tokenizer, # list + self.text_encoder, # list + prompt, + truncate=not long_prompts, + max_length=512, + dropout_prob=dropout_prob, + attn_mask=self.model_config.attn_masking + ) + pe = PromptEmbeds( + prompt_embeds + ) + pe.pooled_embeds = pooled_prompt_embeds + return pe + + + elif isinstance(self.text_encoder, T5EncoderModel): + embeds, attention_mask = train_tools.encode_prompts_pixart( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=256, + dropout_prob=dropout_prob + ) + + # just mask the attention mask + prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape) + embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device) + return PromptEmbeds( + embeds, + + # do we want attn mask here? 
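+ # For reference: the masking above broadcasts the (batch, seq_len) attention mask
+ # to the embedding shape and zeroes padded positions, e.g. a mask row of
+ # [1, 1, 1, 0, 0] keeps the first three token embeddings and zeroes the last two.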
+ # attention_mask=attention_mask, + ) + else: + return PromptEmbeds( + train_tools.encode_prompts( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob + ) + ) + + @torch.no_grad() + def encode_images( + self, + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + latent_list = [] + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + # move to device and dtype + image_list = [image.to(device, dtype=dtype) for image in image_list] + + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + + # resize images if not divisible by 8 + for i in range(len(image_list)): + image = image_list[i] + if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: + image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, + image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) + + images = torch.stack(image_list) + if isinstance(self.vae, AutoencoderTiny): + latents = self.vae.encode(images, return_dict=False)[0] + else: + latents = self.vae.encode(images).latent_dist.sample() + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + + # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303 + # z = self.scale_factor * (z - self.shift_factor) + latents = self.vae.config['scaling_factor'] * (latents - shift) + latents = latents.to(device, dtype=dtype) + + return latents + + def decode_latents( + self, + latents: torch.Tensor, + device=None, + dtype=None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(self.device) + latents = latents.to(device, dtype=dtype) + latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] + images = self.vae.decode(latents).sample + images = images.to(device, dtype=dtype) + + return images + + def encode_image_prompt_pairs( + self, + prompt_list: List[str], + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + # todo check image types and expand and rescale as needed + # device and dtype are for outputs + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + embedding_list = [] + latent_list = [] + # embed the prompts + for prompt in prompt_list: + embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + + return embedding_list, latent_list + + def get_weight_by_name(self, name): + # weights begin with te{te_num}_ for text encoder + # weights begin with unet_ for unet_ + if name.startswith('te'): + key = name[4:] + # text encoder + te_num = int(name[2]) + if isinstance(self.text_encoder, list): + return self.text_encoder[te_num].state_dict()[key] + else: + return self.text_encoder.state_dict()[key] + elif name.startswith('unet'): + key = name[5:] + # unet + return self.unet.state_dict()[key] + + raise ValueError(f"Unknown weight name: {name}") + + def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): + return inject_trigger_into_prompt( + prompt, + trigger=trigger, + 
to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present, + ) + + def state_dict(self, vae=True, text_encoder=True, unet=True): + state_dict = OrderedDict() + if vae: + for k, v in self.vae.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + state_dict[new_key] = v + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + for k, v in encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}" + state_dict[new_key] = v + else: + for k, v in self.text_encoder.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}" + state_dict[new_key] = v + if unet: + for k, v in self.unet.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + state_dict[new_key] = v + return state_dict + + def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \ + OrderedDict[ + str, Parameter]: + named_params: OrderedDict[str, Parameter] = OrderedDict() + if vae: + for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): + named_params[name] = param + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: + # dont add these params + continue + if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: + # dont add these params + continue + + for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): + named_params[name] = param + else: + for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): + named_params[name] = param + if unet: + if self.is_flux: + for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"): + named_params[name] = param + else: + for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): + named_params[name] = param + + if self.model_config.ignore_if_contains is not None: + # remove params that contain the ignore_if_contains from named params + for key in list(named_params.keys()): + if any([s in key for s in self.model_config.ignore_if_contains]): + del named_params[key] + if self.model_config.only_if_contains is not None: + # remove params that do not contain the only_if_contains from named params + for key in list(named_params.keys()): + if not any([s in key for s in self.model_config.only_if_contains]): + del named_params[key] + + if refiner: + for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"): + named_params[name] = param + + # convert to state dict keys, jsut replace . with _ on keys + if state_dict_keys: + new_named_params = OrderedDict() + for k, v in named_params.items(): + # replace only the first . 
with an _ + new_key = k.replace('.', '_', 1) + new_named_params[new_key] = v + named_params = new_named_params + + return named_params + + def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')): + + # load the full refiner since we only train unet + if self.model_config.refiner_name_or_path is None: + raise ValueError("Refiner must be specified to save it") + refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') + # load the refiner model + dtype = get_torch_dtype(self.dtype) + model_path = self.model_config._original_refiner_name_or_path + if not os.path.exists(model_path) or os.path.isdir(model_path): + # TODO only load unet?? + refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( + model_path, + dtype=dtype, + device='cpu', + # variant="fp16", + use_safetensors=True, + ) + else: + refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( + model_path, + dtype=dtype, + device='cpu', + torch_dtype=self.torch_dtype, + original_config_file=refiner_config_path, + ) + # replace original unet + refiner.unet = self.refiner_unet + flush() + + diffusers_state_dict = OrderedDict() + for k, v in refiner.vae.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + diffusers_state_dict[new_key] = v + for k, v in refiner.text_encoder_2.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}" + diffusers_state_dict[new_key] = v + for k, v in refiner.unet.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + diffusers_state_dict[new_key] = v + + converted_state_dict = get_ldm_state_dict_from_diffusers( + diffusers_state_dict, + 'sdxl_refiner', + device='cpu', + dtype=save_dtype + ) + + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) + + if self.config_file is not None: + output_path_no_ext = os.path.splitext(output_file)[0] + output_config_path = f"{output_path_no_ext}.yaml" + shutil.copyfile(self.config_file, output_config_path) + + def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): + version_string = '1' + if self.is_v2: + version_string = '2' + if self.is_xl: + version_string = 'sdxl' + if self.is_ssd: + # overwrite sdxl because both wil be true here + version_string = 'ssd' + if self.is_ssd and self.is_vega: + version_string = 'vega' + # if output file does not end in .safetensors, then it is a directory and we are + # saving in diffusers format + if not output_file.endswith('.safetensors'): + # diffusers + if self.is_flux: + # only save the unet + transformer: FluxTransformer2DModel = self.unet + transformer.save_pretrained( + save_directory=os.path.join(output_file, 'transformer'), + safe_serialization=True, + ) + else: + + self.pipeline.save_pretrained( + save_directory=output_file, + safe_serialization=True, + ) + # save out meta config + meta_path = os.path.join(output_file, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + else: + save_ldm_model_from_diffusers( + sd=self, + output_file=output_file, + meta=meta, + save_dtype=save_dtype, + sd_version=version_string, + ) + if self.config_file is not None: + output_path_no_ext = os.path.splitext(output_file)[0] + output_config_path = f"{output_path_no_ext}.yaml" + shutil.copyfile(self.config_file, output_config_path) + + def 
prepare_optimizer_params( + self, + unet=False, + text_encoder=False, + text_encoder_lr=None, + unet_lr=None, + refiner_lr=None, + refiner=False, + default_lr=1e-6, + ): + # todo maybe only get locon ones? + # not all items are saved, to make it match, we need to match out save mappings + # and not train anything not mapped. Also add learning rate + version = 'sd1' + if self.is_xl: + version = 'sdxl' + if self.is_v2: + version = 'sd2' + mapping_filename = f"stable_diffusion_{version}.json" + mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename) + with open(mapping_path, 'r') as f: + mapping = json.load(f) + ldm_diffusers_keymap = mapping['ldm_diffusers_keymap'] + + trainable_parameters = [] + + # we use state dict to find params + + if unet: + named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True) + unet_lr = unet_lr if unet_lr is not None else default_lr + params = [] + if self.is_pixart or self.is_auraflow or self.is_flux: + for param in named_params.values(): + if param.requires_grad: + params.append(param) + else: + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": unet_lr} + trainable_parameters.append(param_data) + print(f"Found {len(params)} trainable parameter in unet") + + if text_encoder: + named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) + text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": text_encoder_lr} + trainable_parameters.append(param_data) + + print(f"Found {len(params)} trainable parameter in text encoder") + + if refiner: + named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, + state_dict_keys=True) + refiner_lr = refiner_lr if refiner_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + diffusers_key = f"refiner_{diffusers_key}" + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": refiner_lr} + trainable_parameters.append(param_data) + + print(f"Found {len(params)} trainable parameter in refiner") + + return trainable_parameters + + def save_device_state(self): + # saves the current device state for all modules + # this is useful for when we want to alter the state and restore it + if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: + unet_has_grad = self.unet.proj_out.weight.requires_grad + else: + unet_has_grad = self.unet.conv_in.weight.requires_grad + + self.device_state = { + **empty_preset, + 'vae': { + 'training': self.vae.training, + 'device': self.vae.device, + }, + 'unet': { + 'training': self.unet.training, + 'device': self.unet.device, + 'requires_grad': unet_has_grad, + }, + } + if isinstance(self.text_encoder, list): + self.device_state['text_encoder']: List[dict] = [] + for encoder in self.text_encoder: + try: + te_has_grad = 
encoder.text_model.final_layer_norm.weight.requires_grad + except: + te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + self.device_state['text_encoder'].append({ + 'training': encoder.training, + 'device': encoder.device, + # todo there has to be a better way to do this + 'requires_grad': te_has_grad + }) + else: + if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel): + te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + else: + te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad + + self.device_state['text_encoder'] = { + 'training': self.text_encoder.training, + 'device': self.text_encoder.device, + 'requires_grad': te_has_grad + } + if self.adapter is not None: + if isinstance(self.adapter, IPAdapter): + requires_grad = self.adapter.image_proj_model.training + adapter_device = self.unet.device + elif isinstance(self.adapter, T2IAdapter): + requires_grad = self.adapter.adapter.conv_in.weight.requires_grad + adapter_device = self.adapter.device + elif isinstance(self.adapter, ControlNetModel): + requires_grad = self.adapter.conv_in.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ClipVisionAdapter): + requires_grad = self.adapter.embedder.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, CustomAdapter): + requires_grad = self.adapter.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ReferenceAdapter): + # todo update this!! + requires_grad = True + adapter_device = self.adapter.device + else: + raise ValueError(f"Unknown adapter type: {type(self.adapter)}") + self.device_state['adapter'] = { + 'training': self.adapter.training, + 'device': adapter_device, + 'requires_grad': requires_grad, + } + + if self.refiner_unet is not None: + self.device_state['refiner_unet'] = { + 'training': self.refiner_unet.training, + 'device': self.refiner_unet.device, + 'requires_grad': self.refiner_unet.conv_in.weight.requires_grad, + } + + def restore_device_state(self): + # restores the device state for all modules + # this is useful for when we want to alter the state and restore it + if self.device_state is None: + return + self.set_device_state(self.device_state) + self.device_state = None + + def set_device_state(self, state): + if state['vae']['training']: + self.vae.train() + else: + self.vae.eval() + self.vae.to(state['vae']['device']) + if state['unet']['training']: + self.unet.train() + else: + self.unet.eval() + self.unet.to(state['unet']['device']) + if state['unet']['requires_grad']: + self.unet.requires_grad_(True) + else: + self.unet.requires_grad_(False) + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if isinstance(state['text_encoder'], list): + if state['text_encoder'][i]['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder'][i]['device']) + encoder.requires_grad_(state['text_encoder'][i]['requires_grad']) + else: + if state['text_encoder']['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder']['device']) + encoder.requires_grad_(state['text_encoder']['requires_grad']) + else: + if state['text_encoder']['training']: + self.text_encoder.train() + else: + self.text_encoder.eval() + self.text_encoder.to(state['text_encoder']['device']) + self.text_encoder.requires_grad_(state['text_encoder']['requires_grad']) + + if self.adapter is 
not None: + self.adapter.to(state['adapter']['device']) + self.adapter.requires_grad_(state['adapter']['requires_grad']) + if state['adapter']['training']: + self.adapter.train() + else: + self.adapter.eval() + + if self.refiner_unet is not None: + self.refiner_unet.to(state['refiner_unet']['device']) + self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad']) + if state['refiner_unet']['training']: + self.refiner_unet.train() + else: + self.refiner_unet.eval() + flush() + + def set_device_state_preset(self, device_state_preset: DeviceStatePreset): + # sets a preset for device state + + # save current state first + self.save_device_state() + + active_modules = [] + training_modules = [] + if device_state_preset in ['cache_latents']: + active_modules = ['vae'] + if device_state_preset in ['cache_clip']: + active_modules = ['clip'] + if device_state_preset in ['generate']: + active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet'] + + state = copy.deepcopy(empty_preset) + # vae + state['vae'] = { + 'training': 'vae' in training_modules, + 'device': self.vae_device_torch if 'vae' in active_modules else 'cpu', + 'requires_grad': 'vae' in training_modules, + } + + # unet + state['unet'] = { + 'training': 'unet' in training_modules, + 'device': self.device_torch if 'unet' in active_modules else 'cpu', + 'requires_grad': 'unet' in training_modules, + } + + if self.refiner_unet is not None: + state['refiner_unet'] = { + 'training': 'refiner_unet' in training_modules, + 'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu', + 'requires_grad': 'refiner_unet' in training_modules, + } + + # text encoder + if isinstance(self.text_encoder, list): + state['text_encoder'] = [] + for i, encoder in enumerate(self.text_encoder): + state['text_encoder'].append({ + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + }) + else: + state['text_encoder'] = { + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + } + + if self.adapter is not None: + state['adapter'] = { + 'training': 'adapter' in training_modules, + 'device': self.device_torch if 'adapter' in active_modules else 'cpu', + 'requires_grad': 'adapter' in training_modules, + } + + self.set_device_state(state) + + def text_encoder_to(self, *args, **kwargs): + if isinstance(self.text_encoder, list): + for encoder in self.text_encoder: + encoder.to(*args, **kwargs) + else: + self.text_encoder.to(*args, **kwargs) diff --git a/toolkit/style.py b/toolkit/style.py new file mode 100644 index 0000000000000000000000000000000000000000..26ac33fa710b3286323357abc50b13e9bcda9aec --- /dev/null +++ b/toolkit/style.py @@ -0,0 +1,232 @@ +from torch import nn +import torch.nn.functional as F +import torch +from torchvision import models + + +# device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def tensor_size(tensor): + channels = tensor.shape[1] + height = tensor.shape[2] + width = tensor.shape[3] + return channels * height * width + +class ContentLoss(nn.Module): + + def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'): + super(ContentLoss, self).__init__() + self.single_target = single_target + self.device = device + self.loss = None + + def forward(self, stacked_input): + + if self.single_target: + 
split_size = stacked_input.size()[0] // 2 + pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0) + else: + split_size = stacked_input.size()[0] // 3 + pred_layer, _, target_layer = torch.split(stacked_input, split_size, dim=0) + + content_size = tensor_size(pred_layer) + + # Define the separate loss function + def separated_loss(y_pred, y_true): + y_pred = y_pred.float() + y_true = y_true.float() + diff = torch.abs(y_pred - y_true) + l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0 + return 2. * l2 / content_size + + # Calculate itemized loss + pred_itemized_loss = separated_loss(pred_layer, target_layer) + # check if is nan + if torch.isnan(pred_itemized_loss).any(): + print('pred_itemized_loss is nan') + + # Calculate the mean of itemized loss + loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True) + self.loss = loss + + return stacked_input + + +def convert_to_gram_matrix(inputs): + inputs = inputs.float() + shape = inputs.size() + batch, filters, height, width = shape[0], shape[1], shape[2], shape[3] + size = height * width * filters + + feats = inputs.view(batch, filters, height * width) + feats_t = feats.transpose(1, 2) + grams_raw = torch.matmul(feats, feats_t) + gram_matrix = grams_raw / size + + return gram_matrix + + +###################################################################### +# Now the style loss module looks almost exactly like the content loss +# module. The style distance is also computed using the mean square +# error between :math:`G_{XL}` and :math:`G_{SL}`. +# + +class StyleLoss(nn.Module): + + def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'): + super(StyleLoss, self).__init__() + self.single_target = single_target + self.device = device + + def forward(self, stacked_input): + input_dtype = stacked_input.dtype + stacked_input = stacked_input.float() + if self.single_target: + split_size = stacked_input.size()[0] // 2 + preds, style_target = torch.split(stacked_input, split_size, dim=0) + else: + split_size = stacked_input.size()[0] // 3 + preds, style_target, _ = torch.split(stacked_input, split_size, dim=0) + + def separated_loss(y_pred, y_true): + gram_size = y_true.size(1) * y_true.size(2) + sum_axis = (1, 2) + diff = torch.abs(y_pred - y_true) + raw_loss = torch.sum(diff ** 2, dim=sum_axis, keepdim=True) + return raw_loss / gram_size + + target_grams = convert_to_gram_matrix(style_target) + pred_grams = convert_to_gram_matrix(preds) + itemized_loss = separated_loss(pred_grams, target_grams) + # check if is nan + if torch.isnan(itemized_loss).any(): + print('itemized_loss is nan') + # reshape itemized loss to be (batch, 1, 1, 1) + itemized_loss = torch.unsqueeze(itemized_loss, dim=1) + # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2]) + loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True) + self.loss = loss.to(input_dtype).float() + return stacked_input.to(input_dtype) + + +# create a module to normalize input image so we can easily put it in a +# ``nn.Sequential`` +class Normalization(nn.Module): + def __init__(self, device, dtype=torch.float32): + super(Normalization, self).__init__() + mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + std = torch.tensor([0.229, 0.224, 0.225]).to(device) + self.dtype = dtype + # .view the mean and std to make them [C x 1 x 1] so that they can + # directly work with image Tensor of shape [B x C x H x W]. + # B is batch size. C is number of channels. H is height and W is width. 
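# --- Illustrative sketch (annotation, not part of the original patch) ---
# The StyleLoss above compares Gram matrices of VGG feature maps. For a
# (B, C, H, W) activation, the Gram matrix is the (B, C, C) matrix of channel
# correlations normalized by C * H * W, exactly what convert_to_gram_matrix
# computes. A minimal standalone check of that computation:

import torch

def gram(features: torch.Tensor) -> torch.Tensor:
    # features: (batch, channels, height, width)
    b, c, h, w = features.shape
    flat = features.view(b, c, h * w)              # (B, C, H*W)
    g = torch.bmm(flat, flat.transpose(1, 2))      # (B, C, C) channel correlations
    return g / (c * h * w)                         # same normalization as convert_to_gram_matrix

feats = torch.randn(2, 64, 32, 32)
g = gram(feats)
assert g.shape == (2, 64, 64)
# Gram matrices are symmetric across the two channel dimensions
assert torch.allclose(g, g.transpose(1, 2), atol=1e-5)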
+ self.mean = torch.tensor(mean).view(-1, 1, 1) + self.std = torch.tensor(std).view(-1, 1, 1) + + def forward(self, stacked_input): + # cast to float 32 if not already # only necessary when processing gram matrix + # if stacked_input.dtype != torch.float32: + # stacked_input = stacked_input.float() + # remove alpha channel if it exists + if stacked_input.shape[1] == 4: + stacked_input = stacked_input[:, :3, :, :] + # normalize to min and max of 0 - 1 + in_min = torch.min(stacked_input) + in_max = torch.max(stacked_input) + # norm_stacked_input = (stacked_input - in_min) / (in_max - in_min) + # return (norm_stacked_input - self.mean) / self.std + return ((stacked_input - self.mean) / self.std).to(self.dtype) + + +class OutputLayer(nn.Module): + def __init__(self, name='output_layer'): + super(OutputLayer, self).__init__() + self.name = name + self.tensor = None + + def forward(self, stacked_input): + self.tensor = stacked_input + return stacked_input + + +def get_style_model_and_losses( + single_target=True, # false has 3 targets, dont remember why i added this initially, this is old code + device='cuda' if torch.cuda.is_available() else 'cpu', + output_layer_name=None, + dtype=torch.float32 +): + # content_layers = ['conv_4'] + # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] + content_layers = ['conv2_2', 'conv3_2', 'conv4_2'] + style_layers = ['conv2_1', 'conv3_1', 'conv4_1'] + cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval() + # set all weights in the model to our dtype + # for layer in cnn.children(): + # layer.to(dtype=dtype) + + # normalization module + normalization = Normalization(device, dtype=dtype).to(device) + + # just in order to have an iterable access to or list of content/style + # losses + content_losses = [] + style_losses = [] + + # assuming that ``cnn`` is a ``nn.Sequential``, so we make a new ``nn.Sequential`` + # to put in modules that are supposed to be activated sequentially + model = nn.Sequential(normalization) + + i = 0 # increment every time we see a conv + block = 1 + children = list(cnn.children()) + + output_layer = None + + for layer in children: + if isinstance(layer, nn.Conv2d): + i += 1 + name = f'conv{block}_{i}_raw' + elif isinstance(layer, nn.ReLU): + # name = 'relu_{}'.format(i) + name = f'conv{block}_{i}' # target this + # The in-place version doesn't play very nicely with the ``ContentLoss`` + # and ``StyleLoss`` we insert below. So we replace with out-of-place + # ones here. 
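# --- Illustrative usage sketch (annotation, not part of the original patch) ---
# get_style_model_and_losses (its body continues just below) returns a truncated
# VGG19 with StyleLoss/ContentLoss modules spliced in. With single_target=True
# the loss modules expect predictions and targets stacked along the batch
# dimension and expose their per-item results via .loss after a forward pass.
# pred_images / target_images are placeholder tensors here; running this
# downloads the pretrained torchvision VGG19 weights.

import torch

def example_style_content_loss(pred_images: torch.Tensor, target_images: torch.Tensor) -> torch.Tensor:
    model, style_losses, content_losses, _ = get_style_model_and_losses(
        single_target=True, device=pred_images.device.type, dtype=pred_images.dtype
    )
    stacked = torch.cat([pred_images, target_images], dim=0)  # (2B, 3, H, W)
    model(stacked)  # populates .loss on each inserted loss module
    style_term = torch.stack([sl.loss.mean() for sl in style_losses]).sum()
    content_term = torch.stack([cl.loss.mean() for cl in content_losses]).sum()
    return style_term + content_term

# e.g. example_style_content_loss(torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256))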
+ layer = nn.ReLU(inplace=False) + elif isinstance(layer, nn.MaxPool2d): + name = 'pool_{}'.format(i) + block += 1 + i = 0 + elif isinstance(layer, nn.BatchNorm2d): + name = 'bn_{}'.format(i) + else: + raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__)) + + model.add_module(name, layer) + + if name in content_layers: + # add content loss: + content_loss = ContentLoss(single_target=single_target, device=device) + model.add_module("content_loss_{}_{}".format(block, i), content_loss) + content_losses.append(content_loss) + + if name in style_layers: + # add style loss: + style_loss = StyleLoss(single_target=single_target, device=device) + model.add_module("style_loss_{}_{}".format(block, i), style_loss) + style_losses.append(style_loss) + + if output_layer_name is not None and name == output_layer_name: + output_layer = OutputLayer(name) + model.add_module("output_layer_{}_{}".format(block, i), output_layer) + + # now we trim off the layers after the last content and style losses + for i in range(len(model) - 1, -1, -1): + if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss) or isinstance(model[i], OutputLayer): + break + + model = model[:(i + 1)] + model.to(dtype=dtype) + + return model, style_losses, content_losses, output_layer diff --git a/toolkit/timer.py b/toolkit/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4fecba1321aa856808bd7a3290511882b84627 --- /dev/null +++ b/toolkit/timer.py @@ -0,0 +1,65 @@ +import time +from collections import OrderedDict, deque + + +class Timer: + def __init__(self, name='Timer', max_buffer=10): + self.name = name + self.max_buffer = max_buffer + self.timers = OrderedDict() + self.active_timers = {} + self.current_timer = None # Used for the context manager functionality + + def start(self, timer_name): + if timer_name not in self.timers: + self.timers[timer_name] = deque(maxlen=self.max_buffer) + self.active_timers[timer_name] = time.time() + + def cancel(self, timer_name): + """Cancel an active timer.""" + if timer_name in self.active_timers: + del self.active_timers[timer_name] + + def stop(self, timer_name): + if timer_name not in self.active_timers: + raise ValueError(f"Timer '{timer_name}' was not started!") + + elapsed_time = time.time() - self.active_timers[timer_name] + self.timers[timer_name].append(elapsed_time) + + # Clean up active timers + del self.active_timers[timer_name] + + # Check if this timer's buffer exceeds max_buffer and remove the oldest if it does + if len(self.timers[timer_name]) > self.max_buffer: + self.timers[timer_name].popleft() + + def print(self): + print(f"\nTimer '{self.name}':") + # sort by longest at top + for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True): + avg_time = sum(timings) / len(timings) + print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}") + + print('') + + def reset(self): + self.timers.clear() + self.active_timers.clear() + + def __call__(self, timer_name): + """Enable the use of the Timer class as a context manager.""" + self.current_timer = timer_name + self.start(timer_name) + return self + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + # No exceptions, stop the timer normally + self.stop(self.current_timer) + else: + # There was an exception, cancel the timer + self.cancel(self.current_timer) diff --git a/toolkit/train_pipelines.py b/toolkit/train_pipelines.py new file mode 100644 index 
0000000000000000000000000000000000000000..b9cc623cd55e802bcad1de41cd90be6a57d2743a --- /dev/null +++ b/toolkit/train_pipelines.py @@ -0,0 +1,316 @@ +from typing import Optional, Tuple, Callable, Dict, Any, Union, List + +import torch +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg + +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.pipelines import CustomStableDiffusionXLPipeline + + +class TransferStableDiffusionXLPipeline(CustomStableDiffusionXLPipeline): + def transfer_diffuse( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + target_unet: Optional[torch.nn.Module] = None, + pre_condition_callback = None, + each_step_callback = None, + network: Optional[LoRASpecialNetwork] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + conditioned_noise_pred, conditioned_latent_model_input = pre_condition_callback( + noise_pred.clone().detach(), + latent_model_input.clone().detach(), + ) + + # start grad + with torch.enable_grad(): + with network: + assert network.is_active + noise_train_pred = target_unet( + conditioned_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + each_step_callback(conditioned_noise_pred, noise_train_pred) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..592f7cc681c51e653602883436c5750cb53c8873 --- /dev/null +++ b/toolkit/train_tools.py @@ -0,0 +1,768 @@ +import argparse +import hashlib +import json +import os +import time +from typing import TYPE_CHECKING, Union, List +import sys + +from torch.cuda.amp import GradScaler + +from toolkit.paths import SD_SCRIPTS_ROOT + +sys.path.append(SD_SCRIPTS_ROOT) + +from diffusers import ( + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler +) +import torch +import re +from transformers import T5Tokenizer, T5EncoderModel, UMT5EncoderModel + +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL +TEXT_ENCODER_2_PROJECTION_DIM = 1280 +UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816 + + +def get_torch_dtype(dtype_str): + # if it is a torch dtype, return it + if isinstance(dtype_str, torch.dtype): + return dtype_str + if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32": + return torch.float + if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16": + return torch.float16 + if dtype_str == "bf16" or dtype_str == "bfloat16": + return torch.bfloat16 + if dtype_str == "8bit" or dtype_str == "e4m3fn" or dtype_str == "float8": + return torch.float8_e4m3fn + return dtype_str + + +def replace_filewords_prompt(prompt, args: argparse.Namespace): + # if name_replace attr in args (may not be) + if hasattr(args, "name_replace") and args.name_replace is not None: + # replace [name] to args.name_replace + prompt = prompt.replace("[name]", args.name_replace) + if hasattr(args, "prepend") and args.prepend is not None: + # prepend to every item in prompt file + prompt = args.prepend + ' ' + prompt + if hasattr(args, "append") and args.append is not None: + # append to every item in prompt file + prompt = prompt + ' ' + args.append + return prompt + + +def replace_filewords_in_dataset_group(dataset_group, args: argparse.Namespace): + # if name_replace attr in args (may not be) + if hasattr(args, "name_replace") and args.name_replace is not None: + if not len(dataset_group.image_data) > 0: + # throw error + raise ValueError("dataset_group.image_data is empty") + for key in dataset_group.image_data: + dataset_group.image_data[key].caption = dataset_group.image_data[key].caption.replace( + "[name]", args.name_replace) + + return dataset_group + + +def get_seeds_from_latents(latents): + # latents shape = (batch_size, 4, height, width) + # for speed we only use 8x8 slice of the first channel + seeds = [] + + # split batch up + for i in range(latents.shape[0]): + # use 
only first channel, multiply by 255 and convert to int + tensor = latents[i, 0, :, :] * 255.0 # shape = (height, width) + # slice 8x8 + tensor = tensor[:8, :8] + # clip to 0-255 + tensor = torch.clamp(tensor, 0, 255) + # convert to 8bit int + tensor = tensor.to(torch.uint8) + # convert to bytes + tensor_bytes = tensor.cpu().numpy().tobytes() + # hash + hash_object = hashlib.sha256(tensor_bytes) + # get hex + hex_dig = hash_object.hexdigest() + # convert to int + seed = int(hex_dig, 16) % (2 ** 32) + # append + seeds.append(seed) + return seeds + + +def get_noise_from_latents(latents): + seed_list = get_seeds_from_latents(latents) + noise = [] + for seed in seed_list: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + noise.append(torch.randn_like(latents[0])) + return torch.stack(noise) + + +# mix 0 is completely noise mean, mix 1 is completely target mean + +def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None): + dim = dim or (1, 2, 3) + # reduce mean of noise on dim 2, 3, keeping 0 and 1 intact + noise_mean = noise.mean(dim=dim, keepdim=True) + target_mean = target.mean(dim=dim, keepdim=True) + + new_noise_mean = mix * target_mean + (1 - mix) * noise_mean + + noise = noise - noise_mean + new_noise_mean + return noise + + +# https://www.crosslabs.org//blog/diffusion-with-offset-noise +def apply_noise_offset(noise, noise_offset): + if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001): + return noise + noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device) + return noise + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import PromptEmbeds + + +def concat_prompt_embeddings( + unconditional: 'PromptEmbeds', + conditional: 'PromptEmbeds', + n_imgs: int, +): + from toolkit.stable_diffusion_model import PromptEmbeds + text_embeds = torch.cat( + [unconditional.text_embeds, conditional.text_embeds] + ).repeat_interleave(n_imgs, dim=0) + pooled_embeds = None + if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None: + pooled_embeds = torch.cat( + [unconditional.pooled_embeds, conditional.pooled_embeds] + ).repeat_interleave(n_imgs, dim=0) + return PromptEmbeds([text_embeds, pooled_embeds]) + + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + +if TYPE_CHECKING: + from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + + +def text_tokenize( + tokenizer: 'CLIPTokenizer', + prompts: list[str], + truncate: bool = True, + max_length: int = None, + max_length_multiplier: int = 4, +): + # allow fo up to 4x the max length for long prompts + if max_length is None: + if truncate: + max_length = tokenizer.model_max_length + else: + # allow up to 4x the max length for long prompts + max_length = tokenizer.model_max_length * max_length_multiplier + + input_ids = tokenizer( + prompts, + padding='max_length', + max_length=max_length, + truncation=True, + 
return_tensors="pt", + ).input_ids + + if truncate or max_length == tokenizer.model_max_length: + return input_ids + else: + # remove additional padding + num_chunks = input_ids.shape[1] // tokenizer.model_max_length + chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1) + + # New list to store non-redundant chunks + non_redundant_chunks = [] + + for chunk in chunks: + if not chunk.eq(chunk[0, 0]).all(): # Check if all elements in the chunk are the same as the first element + non_redundant_chunks.append(chunk) + + input_ids = torch.cat(non_redundant_chunks, dim=1) + return input_ids + + +# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348 +def text_encode_xl( + text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'], + tokens: torch.FloatTensor, + num_images_per_prompt: int = 1, + max_length: int = 77, # not sure what default to put here, always pass one? + truncate: bool = True, +): + if truncate: + # normal short prompt 77 tokens max + prompt_embeds = text_encoder( + tokens.to(text_encoder.device), output_hidden_states=True + ) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer + else: + # handle long prompts + prompt_embeds_list = [] + tokens = tokens.to(text_encoder.device) + pooled_prompt_embeds = None + for i in range(0, tokens.shape[-1], max_length): + # todo run it through the in a single batch + section_tokens = tokens[:, i: i + max_length] + embeds = text_encoder(section_tokens, output_hidden_states=True) + pooled_prompt_embed = embeds[0] + if pooled_prompt_embeds is None: + # we only want the first ( I think??) + pooled_prompt_embeds = pooled_prompt_embed + prompt_embed = embeds.hidden_states[-2] # always penultimate layer + prompt_embeds_list.append(prompt_embed) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=1) + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, pooled_prompt_embeds + + +def encode_prompts_xl( + tokenizers: list['CLIPTokenizer'], + text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']], + prompts: list[str], + prompts2: Union[list[str], None], + num_images_per_prompt: int = 1, + use_text_encoder_1: bool = True, # sdxl + use_text_encoder_2: bool = True, # sdxl + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +) -> tuple[torch.FloatTensor, torch.FloatTensor]: + # text_encoder and text_encoder_2's penuultimate layer's output + text_embeds_list = [] + pooled_text_embeds = None # always text_encoder_2's pool + if prompts2 is None: + prompts2 = prompts + + for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)): + # todo, we are using a blank string to ignore that encoder for now. + # find a better way to do this (zeroing?, removing it from the unet?) 
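# --- Illustrative sketch (annotation, not part of the original patch) ---
# text_tokenize(truncate=False) above keeps long prompts by padding up to
# 4x the tokenizer's model_max_length, and text_encode_xl then encodes the
# token ids one 77-token window at a time and concatenates the results along
# the sequence dimension. A toy version of that chunk-and-concat pattern,
# with a dummy embedding module standing in for the real text encoder
# (vocab/hidden sizes below are illustrative only):

import torch
import torch.nn as nn

def encode_long_prompt_toy(token_ids: torch.Tensor, window: int = 77) -> torch.Tensor:
    # token_ids: (batch, seq_len) where seq_len is a multiple of `window`
    toy_encoder = nn.Embedding(49408, 768)
    chunks = []
    for start in range(0, token_ids.shape[-1], window):
        section = token_ids[:, start:start + window]   # one 77-token window
        chunks.append(toy_encoder(section))            # (batch, window, hidden)
    return torch.cat(chunks, dim=1)                    # (batch, seq_len, hidden)

# e.g. a two-window prompt:
# encode_long_prompt_toy(torch.randint(0, 49408, (1, 154))).shape == (1, 154, 768)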
+ prompt_list_to_use = prompts if idx == 0 else prompts2 + if idx == 0 and not use_text_encoder_1: + prompt_list_to_use = ["" for _ in prompts] + if idx == 1 and not use_text_encoder_2: + prompt_list_to_use = ["" for _ in prompts] + + if dropout_prob > 0.0: + # randomly drop out prompts + prompt_list_to_use = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use + ] + + text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length) + # set the max length for the next one + if idx == 0: + max_length = text_tokens_input_ids.shape[-1] + + text_embeds, pooled_text_embeds = text_encode_xl( + text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length, + truncate=truncate + ) + + text_embeds_list.append(text_embeds) + + bs_embed = pooled_text_embeds.shape[0] + pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds + +def encode_prompts_sd3( + tokenizers: list['CLIPTokenizer'], + text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]], + prompts: list[str], + num_images_per_prompt: int = 1, + truncate: bool = True, + max_length=None, + dropout_prob=0.0, + pipeline = None, +): + text_embeds_list = [] + pooled_text_embeds = None # always text_encoder_2's pool + + prompt_2 = prompts + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompts + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + device = text_encoders[0].device + + prompt_embed, pooled_prompt_embed = pipeline._get_clip_prompt_embeds( + prompt=prompts, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = pipeline._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + device=device + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + return prompt_embeds, pooled_prompt_embeds + + +# ref for long prompts https://github.com/huggingface/diffusers/issues/2136 +def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None): + if max_length is None and not truncate: + raise ValueError("max_length must be set if truncate is True") + try: + tokens = tokens.to(text_encoder.device) + except Exception as e: + print(e) + print("tokens.device", tokens.device) + print("text_encoder.device", text_encoder.device) + raise e + + if truncate: + return text_encoder(tokens)[0] + else: + # handle long prompts + prompt_embeds_list = [] + for i in range(0, tokens.shape[-1], max_length): + prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0] + prompt_embeds_list.append(prompt_embeds) + + return torch.cat(prompt_embeds_list, dim=1) + + +def encode_prompts( + tokenizer: 'CLIPTokenizer', + text_encoder: 'CLIPTextModel', + prompts: list[str], + 
truncate: bool = True, + max_length=None, + dropout_prob=0.0, +): + if max_length is None: + max_length = tokenizer.model_max_length + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length) + text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length) + + return text_embeddings + + +def encode_prompts_pixart( + tokenizer: 'T5Tokenizer', + text_encoder: 'T5EncoderModel', + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +): + if max_length is None: + # See Section 3.1. of the paper. + max_length = 120 + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + text_inputs = tokenizer( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(text_encoder.device) + + text_input_ids = text_input_ids.to(text_encoder.device) + + prompt_embeds = text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + + return prompt_embeds.last_hidden_state, prompt_attention_mask + + +def encode_prompts_auraflow( + tokenizer: 'T5Tokenizer', + text_encoder: 'UMT5EncoderModel', + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +): + if max_length is None: + max_length = 256 + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + device = text_encoder.device + + text_inputs = tokenizer( + prompts, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_input_ids = text_inputs["input_ids"] + untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + prompt_embeds = text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + return prompt_embeds, prompt_attention_mask + +def encode_prompts_flux( + tokenizer: List[Union['CLIPTokenizer','T5Tokenizer']], + text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']], + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, + attn_mask: bool = False, +): + if max_length is None: + max_length = 512 + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + device = text_encoder[0].device + dtype = text_encoder[0].dtype + + batch_size = len(prompts) + + # clip + text_inputs = 
tokenizer[0]( + prompts, + padding="max_length", + max_length=tokenizer[0].model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + pooled_prompt_embeds = prompt_embeds.pooler_output + pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device) + + # T5 + text_inputs = tokenizer[1]( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = text_encoder[1].dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + if attn_mask: + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device) + + return prompt_embeds, pooled_prompt_embeds + + +# for XL +def get_add_time_ids( + height: int, + width: int, + dynamic_crops: bool = False, + dtype: torch.dtype = torch.float32, +): + if dynamic_crops: + # random float scale between 1 and 3 + random_scale = torch.rand(1).item() * 2 + 1 + original_size = (int(height * random_scale), int(width * random_scale)) + # random position + crops_coords_top_left = ( + torch.randint(0, original_size[0] - height, (1,)).item(), + torch.randint(0, original_size[1] - width, (1,)).item(), + ) + target_size = (height, width) + else: + original_size = (height, width) + crops_coords_top_left = (0, 0) + target_size = (height, width) + + # this is expected as 6 + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + # this is expected as 2816 + passed_add_embed_dim = ( + UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6 + + TEXT_ENCODER_2_PROJECTION_DIM # + 1280 + ) + if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM: + raise ValueError( + f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + +def concat_embeddings( + unconditional: torch.FloatTensor, + conditional: torch.FloatTensor, + n_imgs: int, +): + return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0) + + +def add_all_snr_to_noise_scheduler(noise_scheduler, device): + try: + if hasattr(noise_scheduler, "all_snr"): + return + # compute it + with torch.no_grad(): + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + all_snr.requires_grad = False + noise_scheduler.all_snr = all_snr.to(device) + except Exception as e: + # just move on + pass + + +def get_all_snr(noise_scheduler, device): + if hasattr(noise_scheduler, "all_snr"): + return noise_scheduler.all_snr.to(device) + # compute it + with torch.no_grad(): + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + all_snr.requires_grad = False + return all_snr.to(device) + +class LearnableSNRGamma: + """ + This is a trainer for learnable snr gamma + It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps + """ + def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'): + self.device = device + self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler + self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device)) + self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device)) + self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device)) + self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device)) + self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01) + self.buffer = [] + self.max_buffer_size = 20 + + def forward(self, loss, timesteps): + # do a our train loop for lsnr here and return our values detached + loss = loss.detach() + with torch.no_grad(): + loss_chunks = torch.chunk(loss, loss.shape[0], dim=0) + for loss_chunk in loss_chunks: + self.buffer.append(loss_chunk.mean().detach()) + if len(self.buffer) > self.max_buffer_size: + self.buffer.pop(0) + all_snr = get_all_snr(self.noise_scheduler, loss.device) + snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device) + base_snrs = snr.clone().detach() + snr.requires_grad = True + snr = (snr + self.offset_1) * self.scale + self.offset_2 + + gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr) + snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr + snr_adjusted_loss = loss * snr_weight + with torch.no_grad(): + target = torch.mean(torch.stack(self.buffer)).detach() + + # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target)) + squared_differences = (snr_adjusted_loss - target) ** 2 + local_loss = torch.mean(squared_differences) + local_loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + + return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach() + + +def apply_learnable_snr_gos( + 
+        loss,
+        timesteps,
+        learnable_snr_trainer: LearnableSNRGamma
+):
+
+    snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)
+
+    snr = (snr + offset_1) * scale + offset_2
+
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+    snr_adjusted_loss = loss * snr_weight
+
+    return snr_adjusted_loss
+
+
+def apply_snr_weight(
+        loss,
+        timesteps,
+        noise_scheduler: Union['DDPMScheduler'],
+        gamma,
+        fixed=False,
+):
+    # get it from the noise scheduler if it exists, otherwise calculate it
+    all_snr = get_all_snr(noise_scheduler, loss.device)
+    # step_indices = []
+    # for t in timesteps:
+    #     for i, st in enumerate(noise_scheduler.timesteps):
+    #         if st == t:
+    #             step_indices.append(i)
+    #             break
+    # this breaks on some schedulers
+    # step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
+
+    offset = 0
+    if noise_scheduler.timesteps[0] == 1000:
+        offset = 1
+    snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    if fixed:
+        snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
+    else:
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
+    snr_adjusted_loss = loss * snr_weight
+
+    return snr_adjusted_loss
+
+
+def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler):
+    mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
+    mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+    timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+    out_chunks = []
+    # unsqueeze if timestep is zero dim
+    for idx in range(model_output.shape[0]):
+        sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim,
+                                            dtype=model_output.dtype, device=model_output.device)
+        # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+        # Preconditioning of the model outputs.
+        out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
+        out_chunks.append(out)
+    return torch.cat(out_chunks, dim=0)
diff --git a/toolkit/util/inverse_cfg.py b/toolkit/util/inverse_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c85544a95c1a81cd5f7f6e4cf9ca3408e92c81e
--- /dev/null
+++ b/toolkit/util/inverse_cfg.py
@@ -0,0 +1,25 @@
+import torch
+
+
+def inverse_classifier_guidance(
+        noise_pred_cond: torch.Tensor,
+        noise_pred_uncond: torch.Tensor,
+        guidance_scale: torch.Tensor
+):
+    """
+    Adjust noise_pred_cond for the classifier-free guidance algorithm
+    to ensure that the final noise prediction equals the original noise_pred_cond.
+    """
+    # To make noise_pred equal noise_pred_cond_orig, we adjust noise_pred_cond
+    # based on the formula used in the algorithm.
+    # We derive the formula to find the correct adjustment for noise_pred_cond:
+    # noise_pred_cond = (noise_pred_cond_orig - noise_pred_uncond * guidance_scale) / (guidance_scale - 1)
+    # It's important to check that guidance_scale is not 1 to avoid division by zero.
+    if guidance_scale == 1:
+        # If guidance_scale is 1, adjusting is not needed or possible in the same way,
+        # since it would lead to division by zero. This also means the algorithm inherently
+        # doesn't alter noise_pred_cond in relation to noise_pred_uncond.
+        # Thus, we return the original values, though this situation might need special handling.
+        return noise_pred_cond
+    adjusted_noise_pred_cond = (noise_pred_cond - noise_pred_uncond) / guidance_scale
+    return adjusted_noise_pred_cond
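
For context, the sketch below shows one way the `apply_snr_weight` helper defined above could be wired into a diffusion training step. It is only an illustration, not code from this diff: the import path (`toolkit.train_tools`), the use of diffusers' `DDPMScheduler`, the dummy tensors, and `gamma=5.0` are all assumptions made for the example.

```python
# Hypothetical usage sketch of apply_snr_weight; not part of this diff.
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

from toolkit.train_tools import apply_snr_weight  # assumed module path

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

batch_size = 4
# stand-ins for the model prediction and the training target (e.g. the added noise)
model_pred = torch.randn(batch_size, 4, 64, 64)
target = torch.randn(batch_size, 4, 64, 64)
timesteps = torch.randint(0, 1000, (batch_size,))

# keep one loss value per sample so it broadcasts against the per-sample SNR weight
loss = F.mse_loss(model_pred, target, reduction="none")
loss = loss.mean(dim=list(range(1, loss.ndim)))

# min-SNR weighting: gamma / SNR is clamped at 1 unless fixed=True; gamma=5.0 is only a common choice
loss = apply_snr_weight(loss, timesteps, noise_scheduler, gamma=5.0)
loss = loss.mean()
```

The same per-sample loss could instead be passed through `apply_learnable_snr_gos` together with a `LearnableSNRGamma` instance when the weighting should be learned from the data rather than set by a hand-picked gamma.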